mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
5 Commits
feat/robot
...
envs/suppo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ec9392bcb | ||
|
|
84b34ae75c | ||
|
|
ff267c772b | ||
|
|
652b1b854d | ||
|
|
8831b3c47b |
2
.github/workflows/full_tests.yml
vendored
2
.github/workflows/full_tests.yml
vendored
@@ -173,8 +173,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
25
AI_POLICY.md
25
AI_POLICY.md
@@ -1,25 +0,0 @@
|
||||
# AI Usage Policy
|
||||
|
||||
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
|
||||
|
||||
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
|
||||
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
|
||||
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
|
||||
|
||||
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
|
||||
|
||||
## Remember the Human Maintainers
|
||||
|
||||
Please remember that LeRobot is maintained by a dedicated team of humans.
|
||||
|
||||
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
|
||||
|
||||
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
|
||||
|
||||
## AI is Welcome Here
|
||||
|
||||
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
|
||||
|
||||
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
|
||||
|
||||
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -55,7 +55,8 @@ To make your environment loadable from the Hub, your repository must contain at
|
||||
|
||||
**`env.py`** (or custom Python file)
|
||||
|
||||
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
|
||||
- Must expose a `make_env(n_envs: int, use_async_envs: bool, **kwargs)` function
|
||||
- The function should accept `**kwargs` to allow users to pass custom configurations
|
||||
- This function should return one of:
|
||||
- A `gym.vector.VectorEnv` (most common)
|
||||
- A single `gym.Env` (will be automatically wrapped)
|
||||
@@ -99,6 +100,8 @@ Create an `env.py` file with a `make_env` function:
|
||||
```python
|
||||
# env.py
|
||||
import gymnasium as gym
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
def make_env(n_envs: int = 1, use_async_envs: bool = False):
|
||||
"""
|
||||
@@ -250,6 +253,76 @@ envs_dict = make_env(
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Configuration via kwargs
|
||||
|
||||
Hub environments can accept custom configurations through keyword arguments. This is useful for parameterizing tasks, loading different objects, or overriding default settings:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
# Pass a config file path
|
||||
envs_dict = make_env(
|
||||
"nvkartik/isaaclab-arena-envs:envs/microwave_g1.py",
|
||||
n_envs=4,
|
||||
trust_remote_code=True,
|
||||
config_path=Path("/path/to/my_config.yaml"),
|
||||
)
|
||||
|
||||
# Pass config overrides as a dictionary
|
||||
envs_dict = make_env(
|
||||
"nvkartik/isaaclab-arena-envs:envs/microwave_g1.py",
|
||||
n_envs=4,
|
||||
trust_remote_code=True,
|
||||
config_overrides={
|
||||
"scene.object": "microwave",
|
||||
"sim.dt": 0.01,
|
||||
},
|
||||
)
|
||||
|
||||
# Combine config path with overrides
|
||||
envs_dict = make_env(
|
||||
"username/my-env",
|
||||
n_envs=4,
|
||||
trust_remote_code=True,
|
||||
config_path="configs/gr1_pick_place.yaml",
|
||||
config_overrides={"scene.table_objects": ["apple", "banana", "cup"]},
|
||||
)
|
||||
```
|
||||
|
||||
Any keyword arguments you pass will be forwarded to the hub environment's `make_env` function. Check the environment's documentation for supported configuration options.
|
||||
|
||||
### Using Custom kwargs with lerobot-eval
|
||||
|
||||
When evaluating policies using the `lerobot-eval` CLI, you can pass custom kwargs to hub environments using the `--env_kwargs.` prefix:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=user123/example-policy-checkpoint \
|
||||
--env=user123/example-sim-backend \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--env_kwargs.task_id=demo_task_alpha \
|
||||
--env_kwargs.agent_profile=arm_v2 \
|
||||
--env_kwargs.target_item=object_red \
|
||||
--env_kwargs.run_mode=offscreen \
|
||||
--env_kwargs.enable_sensors=true \
|
||||
--env_kwargs.record_output=true \
|
||||
--env_kwargs.output_horizon=10 \
|
||||
--env_kwargs.output_stride=15 \
|
||||
--env_kwargs.state_features=joint_angles \
|
||||
--env_kwargs.visual_streams=front_camera
|
||||
```
|
||||
|
||||
All `--env_kwargs.*` arguments will be collected into a dictionary and passed as keyword arguments to the hub environment's `make_env` function. This allows you to:
|
||||
|
||||
- Pass configuration file paths
|
||||
- Override default settings
|
||||
- Specify custom task parameters
|
||||
- Control simulation options (headless mode, camera settings, etc.)
|
||||
- Select different embodiments or objects
|
||||
|
||||
The hub environment's `make_env` function receives these as regular keyword arguments, so check the environment's documentation for the available options.
|
||||
|
||||
## URL Format Reference
|
||||
|
||||
The hub URL format supports several patterns:
|
||||
@@ -266,7 +339,7 @@ The hub URL format supports several patterns:
|
||||
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
|
||||
|
||||
```python
|
||||
def make_env(n_envs: int = 1, use_async_envs: bool = False):
|
||||
def make_env(n_envs: int = 1, use_async_envs: bool = False, **kwargs):
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
# Return dict: {suite_name: {task_id: VectorEnv}}
|
||||
@@ -388,8 +461,9 @@ pip install gymnasium numpy
|
||||
Your `env.py` must expose a `make_env` function:
|
||||
|
||||
```python
|
||||
def make_env(n_envs: int, use_async_envs: bool):
|
||||
def make_env(n_envs: int, use_async_envs: bool, **kwargs):
|
||||
# Your implementation
|
||||
# kwargs can include config_path, config_overrides, etc.
|
||||
pass
|
||||
```
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -18,6 +18,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
@@ -70,6 +71,9 @@ def main():
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
@@ -95,6 +99,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -109,6 +116,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
@@ -45,6 +46,9 @@ def main():
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
@@ -89,6 +93,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -104,6 +111,9 @@ def main():
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -17,16 +17,30 @@
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import build_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -37,10 +51,6 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
@@ -54,31 +64,68 @@ def main():
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Attach FK/IK pipelines so the robot works in EE space
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
robot.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
robot.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create the dataset — obs auto-derived from FK pipeline, EE action spec explicit
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(
|
||||
robot,
|
||||
use_videos=True,
|
||||
action_features={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
},
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
@@ -104,18 +151,21 @@ def main():
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop — pipelines applied internally by robot
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -130,6 +180,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -16,17 +16,21 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import make_so10x_fk_observation_pipeline
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
@@ -35,7 +39,6 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import build_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -46,10 +49,6 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
@@ -66,59 +65,77 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path=URDF_PATH,
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Phone output pipeline: map raw phone gesture to EE delta (no robot obs needed)
|
||||
phone.set_output_pipeline(
|
||||
RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[MapPhoneActionToRobotAction(platform=teleop_config.phone_os)],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Robot FK observation pipeline: joints → EE pose
|
||||
robot.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
|
||||
# Robot input pipeline: EE delta + current robot obs → joint commands
|
||||
robot.set_input_pipeline(
|
||||
RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=motor_names,
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Dataset features auto-derived from robot's FK obs pipeline and phone's mapped action pipeline
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(robot, phone, use_videos=True),
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
@@ -141,7 +158,7 @@ def main():
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop — pipelines applied internally by robot and phone
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -151,6 +168,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -166,6 +186,9 @@ def main():
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -87,8 +87,8 @@ from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
_make_identity_observation_pipeline as make_default_robot_observation_processor,
|
||||
_make_identity_robot_action_pipeline as make_default_robot_action_processor,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
|
||||
@@ -17,16 +17,30 @@
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import build_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -37,10 +51,6 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
@@ -54,31 +64,68 @@ def main():
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Attach FK/IK pipelines so the robot works in EE space
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
robot.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
robot.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create the dataset — obs auto-derived from FK pipeline, EE action spec explicit
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(
|
||||
robot,
|
||||
use_videos=True,
|
||||
action_features={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
},
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
@@ -88,7 +135,7 @@ def main():
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
@@ -104,18 +151,21 @@ def main():
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop — pipelines applied internally by robot
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -130,6 +180,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -17,20 +17,25 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.pipelines import make_so10x_leader_fk_pipeline
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import (
|
||||
build_dataset_features,
|
||||
check_action_space_compatibility,
|
||||
check_observation_space_compatibility,
|
||||
)
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -41,10 +46,6 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: Use the URDF from the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
@@ -61,17 +62,77 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# Attach EE-space pipelines to the objects
|
||||
motor_names = list(follower.bus.motors.keys())
|
||||
follower.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
follower.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
leader.set_output_pipeline(make_so10x_leader_fk_pipeline(URDF_PATH, list(leader.bus.motors.keys())))
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Dataset features are derived automatically from robot/teleop pipelines
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(follower, leader, use_videos=True),
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
@@ -81,13 +142,9 @@ def main():
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Verify action/observation space alignment (warns on mismatch)
|
||||
check_action_space_compatibility(leader, follower)
|
||||
check_observation_space_compatibility(follower, leader)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_ee")
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
@@ -98,8 +155,7 @@ def main():
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Pipelines applied automatically inside robot.get_observation(),
|
||||
# teleop.get_action(), and robot.send_action()
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
@@ -109,6 +165,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -124,6 +183,9 @@ def main():
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -14,23 +14,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_teleoperate import teleop_loop
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.pipelines import make_so10x_leader_fk_pipeline
|
||||
from lerobot.utils.pipeline_utils import check_action_space_compatibility
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# NOTE: Use the URDF from the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize the robot and teleoperator config
|
||||
@@ -43,14 +47,47 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# Attach EE-space pipelines to the objects
|
||||
motor_names = list(follower.bus.motors.keys())
|
||||
follower.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
follower.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
leader.set_output_pipeline(make_so10x_leader_fk_pipeline(URDF_PATH, list(leader.bus.motors.keys())))
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Verify action space alignment (warns if leader EE ≠ follower action_features)
|
||||
check_action_space_compatibility(leader, follower)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
@@ -60,12 +97,28 @@ def main():
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
try:
|
||||
# Pipelines applied automatically inside teleop.get_action() and robot.send_action()
|
||||
teleop_loop(teleop=leader, robot=follower, fps=FPS, display_data=True)
|
||||
finally:
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.5"
|
||||
version = "0.4.4"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -214,9 +214,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
|
||||
@@ -49,18 +49,23 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import (
|
||||
RobotConfig, # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .constants import SUPPORTED_ROBOTS
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
@@ -480,9 +485,8 @@ class RobotClient:
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
@@ -508,5 +512,4 @@ def async_client(cfg: RobotClientConfig):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
|
||||
@@ -27,7 +27,7 @@ class DatasetConfig:
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
|
||||
@@ -38,6 +38,8 @@ class EvalPipelineConfig:
|
||||
seed: int | None = 1000
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
# Additional kwargs to pass to hub environments (e.g., config_path, config_overrides, custom params)
|
||||
env_kwargs: dict = field(default_factory=dict)
|
||||
# Explicit consent to execute remote code from the Hub (required for hub environments).
|
||||
trust_remote_code: bool = False
|
||||
|
||||
|
||||
@@ -7,13 +7,6 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -567,22 +567,20 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -624,10 +622,9 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -637,20 +634,21 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(frame_ranges):
|
||||
if range_idx >= len(time_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -663,7 +661,6 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -752,17 +749,15 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
|
||||
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for the README).
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory where the dataset will be downloaded and
|
||||
stored. If set, all dataset files will be stored directly under this path. If not set, the
|
||||
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
|
||||
HF_LEROBOT_HOME environment variable).
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -1771,12 +1771,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
@@ -12,14 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
@@ -41,3 +41,99 @@ def create_initial_features(
|
||||
if observation:
|
||||
features[PipelineFeatureType.OBSERVATION] = observation
|
||||
return features
|
||||
|
||||
|
||||
# Helper to filter state/action keys based on regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[str]) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
|
||||
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
|
||||
for prefix in prefixes_to_strip:
|
||||
if key.startswith(prefix):
|
||||
return key[len(prefix) :]
|
||||
return key
|
||||
|
||||
|
||||
# Define prefixes to strip from feature keys for clean names.
|
||||
# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms.
|
||||
PREFIXES_TO_STRIP = tuple(
|
||||
f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1])
|
||||
)
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
|
||||
pipeline: DataProcessorPipeline,
|
||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||
|
||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||
formats them for use with a Hugging Face LeRobot Dataset.
|
||||
|
||||
Args:
|
||||
pipeline: The DataProcessorPipeline to apply.
|
||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||
use_videos: If False, image features are excluded.
|
||||
patterns: A sequence of regex patterns to filter action and state features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
Returns:
|
||||
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
|
||||
"""
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
processed_features: dict[str, dict[str, Any]] = {
|
||||
ACTION: {},
|
||||
OBS_STR: {},
|
||||
}
|
||||
images_token = OBS_IMAGES.split(".")[-1]
|
||||
|
||||
# Iterate through all features transformed by the pipeline.
|
||||
for ptype, feats in all_features.items():
|
||||
if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]:
|
||||
continue
|
||||
|
||||
for key, value in feats.items():
|
||||
# 1. Categorize the feature.
|
||||
is_action = ptype == PipelineFeatureType.ACTION
|
||||
# Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3.
|
||||
# All other observations are treated as state.
|
||||
is_image = not is_action and (
|
||||
(isinstance(value, tuple) and len(value) == 3)
|
||||
or (
|
||||
key.startswith(f"{OBS_IMAGES}.")
|
||||
or key.startswith(f"{images_token}.")
|
||||
or f".{images_token}." in key
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, patterns):
|
||||
continue
|
||||
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
name = strip_prefix(key, PREFIXES_TO_STRIP)
|
||||
if is_action:
|
||||
processed_features[ACTION][name] = value
|
||||
else:
|
||||
processed_features[OBS_STR][name] = value
|
||||
|
||||
# Convert the processed features into the final dataset format.
|
||||
dataset_features = {}
|
||||
if processed_features[ACTION]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
|
||||
if processed_features[OBS_STR]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
|
||||
|
||||
return dataset_features
|
||||
|
||||
@@ -227,17 +227,16 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -249,11 +248,7 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -358,16 +353,15 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
|
||||
@@ -105,6 +105,7 @@ def make_env(
|
||||
use_async_envs: bool = False,
|
||||
hub_cache_dir: str | None = None,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Makes a gym vector environment according to the config or Hub reference.
|
||||
|
||||
@@ -118,6 +119,9 @@ def make_env(
|
||||
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
|
||||
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
|
||||
Default False — must be set to True to import/exec hub `env.py`.
|
||||
**kwargs: Additional keyword arguments passed to the hub environment's `make_env` function.
|
||||
Useful for passing custom configurations like `config_path`, `config_overrides`, etc.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
@@ -149,9 +153,11 @@ def make_env(
|
||||
# import and surface clear import errors
|
||||
module = _import_hub_module(local_file, repo_id)
|
||||
|
||||
# call the hub-provided make_env
|
||||
# call the hub-provided make_env with any additional kwargs
|
||||
env_cfg = None if isinstance(cfg, str) else cfg
|
||||
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg)
|
||||
raw_result = _call_make_env(
|
||||
module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg, **kwargs
|
||||
)
|
||||
|
||||
# normalize the return into {suite: {task_id: vec_env}}
|
||||
return _normalize_hub_result(raw_result)
|
||||
|
||||
@@ -311,20 +311,27 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any:
|
||||
return module
|
||||
|
||||
|
||||
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None) -> Any:
|
||||
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None, **kwargs) -> Any:
|
||||
"""
|
||||
Ensure module exposes make_env and call it.
|
||||
Ensure module exposes make_env and call it with any additional kwargs.
|
||||
|
||||
Args:
|
||||
module: The imported hub module containing make_env.
|
||||
n_envs: Number of parallel environments.
|
||||
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv.
|
||||
**kwargs: Additional keyword arguments to pass to the hub's make_env function.
|
||||
Common examples include config_path, config_overrides, etc.
|
||||
"""
|
||||
if not hasattr(module, "make_env"):
|
||||
raise AttributeError(
|
||||
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
|
||||
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool, **kwargs)`."
|
||||
)
|
||||
entry_fn = module.make_env
|
||||
# Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID)
|
||||
if cfg is not None:
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg)
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg, **kwargs)
|
||||
else:
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, **kwargs)
|
||||
|
||||
|
||||
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
|
||||
@@ -55,16 +55,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
|
||||
backbone. If None, no resizing is done and the original image resolution is used.
|
||||
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
|
||||
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
|
||||
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
|
||||
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
|
||||
this is computed automatically. Can also be set directly for legacy configs that use
|
||||
crop-only (without resize). If None and no derivation applies, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center
|
||||
crop in eval mode).
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
@@ -120,9 +114,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
resize_shape: tuple[int, int] | None = None
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
@@ -147,10 +139,6 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
@@ -183,25 +171,6 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
if self.resize_shape is not None and (
|
||||
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
|
||||
):
|
||||
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
|
||||
if not (0 < self.crop_ratio <= 1.0):
|
||||
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
|
||||
|
||||
if self.resize_shape is not None:
|
||||
if self.crop_ratio < 1.0:
|
||||
self.crop_shape = (
|
||||
int(self.resize_shape[0] * self.crop_ratio),
|
||||
int(self.resize_shape[1] * self.crop_ratio),
|
||||
)
|
||||
else:
|
||||
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
|
||||
self.crop_shape = None
|
||||
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
|
||||
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
@@ -229,12 +198,13 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.resize_shape is None and self.crop_shape is not None:
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for `{key}`."
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
|
||||
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
@@ -454,18 +446,12 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if config.resize_shape is not None:
|
||||
self.resize = torchvision.transforms.Resize(config.resize_shape)
|
||||
else:
|
||||
self.resize = None
|
||||
|
||||
crop_shape = config.crop_shape
|
||||
if crop_shape is not None:
|
||||
if config.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -491,16 +477,13 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
if config.crop_shape is not None:
|
||||
dummy_shape_h_w = config.crop_shape
|
||||
elif config.resize_shape is not None:
|
||||
dummy_shape_h_w = config.resize_shape
|
||||
else:
|
||||
dummy_shape_h_w = images_shape[1:]
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
@@ -516,10 +499,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: resize if configured, then crop if configured.
|
||||
|
||||
if self.resize is not None:
|
||||
x = self.resize(x)
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
||||
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -106,9 +106,6 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -593,12 +593,6 @@ class VLAFlowMatching(nn.Module):
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
# Compile model if requested
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map=device,
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
@@ -30,6 +30,12 @@ from .core import (
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
from .gym_action_processor import (
|
||||
Numpy2TorchActionProcessorStep,
|
||||
Torch2NumpyActionProcessorStep,
|
||||
@@ -89,7 +95,11 @@ __all__ = [
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"make_default_processors",
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
from .converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
@@ -25,44 +24,39 @@ from .core import RobotAction, RobotObservation
|
||||
from .pipeline import IdentityProcessorStep, RobotProcessorPipeline
|
||||
|
||||
|
||||
# ── Internal identity pipeline helpers (used by Robot/Teleoperator base classes) ──────────────────
|
||||
|
||||
|
||||
def _make_identity_observation_pipeline() -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
|
||||
"""Identity pipeline for robot observations (get_observation output pipeline)."""
|
||||
return RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
def _make_identity_robot_action_pipeline() -> RobotProcessorPipeline[
|
||||
def make_default_teleop_action_processor() -> RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
]:
|
||||
"""Identity pipeline for robot action input (send_action input pipeline, takes (action, obs) tuple)."""
|
||||
return RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
teleop_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
return teleop_action_processor
|
||||
|
||||
|
||||
def _make_identity_teleop_action_pipeline() -> RobotProcessorPipeline[RobotAction, RobotAction]:
|
||||
"""Identity pipeline for teleop action output (get_action output pipeline, takes just action)."""
|
||||
return RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
def make_default_robot_action_processor() -> RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
]:
|
||||
robot_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
return robot_action_processor
|
||||
|
||||
|
||||
def _make_identity_feedback_pipeline() -> RobotProcessorPipeline[dict, dict]:
|
||||
"""Identity pipeline for teleop feedback input (send_feedback input pipeline)."""
|
||||
return RobotProcessorPipeline[dict, dict](
|
||||
def make_default_robot_observation_processor() -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
|
||||
robot_observation_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return robot_observation_processor
|
||||
|
||||
|
||||
def make_default_processors():
|
||||
teleop_action_processor = make_default_teleop_action_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
return (teleop_action_processor, robot_action_processor, robot_observation_processor)
|
||||
|
||||
@@ -19,17 +19,15 @@ from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
|
||||
|
||||
@@ -43,9 +43,12 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
from .core import EnvTransition, RobotObservation, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Type-checking only import — do NOT import transformers at module level (it loads TF which blocks)
|
||||
if TYPE_CHECKING:
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
else:
|
||||
AutoProcessor = None
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -103,7 +106,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
# Use provided tokenizer object directly
|
||||
self.input_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
from transformers import AutoTokenizer # lazy import to avoid TF deadlock at module load
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("AutoTokenizer is not available")
|
||||
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -366,12 +370,12 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer # lazy import to avoid TF deadlock at module load
|
||||
|
||||
if self.action_tokenizer_input_object is not None:
|
||||
self.action_tokenizer = self.action_tokenizer_input_object
|
||||
|
||||
elif self.action_tokenizer_name is not None:
|
||||
if AutoProcessor is None:
|
||||
raise ImportError("AutoProcessor is not available")
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
self.action_tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
|
||||
@@ -102,11 +102,11 @@ class BiOpenArmFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -136,7 +136,7 @@ class BiOpenArmFollower(Robot):
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -150,7 +150,7 @@ class BiOpenArmFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
|
||||
@@ -86,11 +86,11 @@ class BiSOFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -119,7 +119,7 @@ class BiSOFollower(Robot):
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -133,7 +133,7 @@ class BiSOFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
|
||||
@@ -147,7 +147,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
pass
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Define the observation space for dataset recording.
|
||||
|
||||
Returns:
|
||||
@@ -184,7 +184,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Define the action space.
|
||||
|
||||
Returns:
|
||||
@@ -198,7 +198,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
Returns:
|
||||
@@ -255,7 +255,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
return observation
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -71,11 +71,11 @@ class HopeJrArm(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -128,7 +128,7 @@ class HopeJrArm(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
|
||||
@@ -147,7 +147,7 @@ class HopeJrArm(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
|
||||
@@ -107,11 +107,11 @@ class HopeJrHand(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -158,7 +158,7 @@ class HopeJrHand(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Read hand position
|
||||
@@ -178,7 +178,7 @@ class HopeJrHand(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return action
|
||||
|
||||
@@ -73,11 +73,11 @@ class KochFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -182,7 +182,7 @@ class KochFollower(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -200,7 +200,7 @@ class KochFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -98,11 +98,11 @@ class LeKiwi(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._state_ft
|
||||
|
||||
@property
|
||||
@@ -338,7 +338,7 @@ class LeKiwi(Robot):
|
||||
} # m/s and deg/s
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read actuators position for arm and vel for base
|
||||
start = time.perf_counter()
|
||||
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||
@@ -367,7 +367,7 @@ class LeKiwi(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -98,11 +98,11 @@ class LeKiwiClient(Robot):
|
||||
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._state_ft
|
||||
|
||||
@property
|
||||
@@ -250,7 +250,7 @@ class LeKiwiClient(Robot):
|
||||
return new_frames, new_state
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
@@ -304,7 +304,7 @@ class LeKiwiClient(Robot):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
Args:
|
||||
|
||||
@@ -73,11 +73,11 @@ class OmxFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -165,7 +165,7 @@ class OmxFollower(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -183,7 +183,7 @@ class OmxFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -105,12 +105,12 @@ class OpenArmFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Combined observation features from motors and cameras."""
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Action features."""
|
||||
return self._motors_ft
|
||||
|
||||
@@ -219,7 +219,7 @@ class OpenArmFollower(Robot):
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
|
||||
@@ -251,7 +251,7 @@ class OpenArmFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
|
||||
@@ -95,11 +95,11 @@ class Reachy2Robot(Robot):
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
@property
|
||||
def raw_observation_features(self) -> dict[str, Any]:
|
||||
def observation_features(self) -> dict[str, Any]:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self.motors_features
|
||||
|
||||
@property
|
||||
@@ -170,7 +170,7 @@ class Reachy2Robot(Robot):
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
@@ -184,7 +184,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
@@ -18,11 +18,8 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.processor.core import RobotAction, RobotObservation
|
||||
from lerobot.processor.factory import _make_identity_observation_pipeline, _make_identity_robot_action_pipeline
|
||||
from lerobot.processor.pipeline import RobotProcessorPipeline
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS
|
||||
|
||||
from .config import RobotConfig
|
||||
@@ -37,10 +34,6 @@ class Robot(abc.ABC):
|
||||
This class provides a standardized interface for interacting with physical robots.
|
||||
Subclasses must implement all abstract methods and properties to be usable.
|
||||
|
||||
Pipelines are first-class citizens: every robot carries an optional output pipeline
|
||||
(applied in get_observation()) and an optional input pipeline (applied in send_action()).
|
||||
Both default to identity (no-op), so existing robots work without any changes.
|
||||
|
||||
Attributes:
|
||||
config_class (RobotConfig): The expected configuration class for this robot.
|
||||
name (str): The unique robot name used to identify this robot type.
|
||||
@@ -62,12 +55,6 @@ class Robot(abc.ABC):
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
# Pipeline interface — default to identity (no-op), swap via set_output/input_pipeline()
|
||||
self._output_pipeline: RobotProcessorPipeline = _make_identity_observation_pipeline()
|
||||
self._input_pipeline: RobotProcessorPipeline = _make_identity_robot_action_pipeline()
|
||||
# Cache of most recent raw observation; used by input_pipeline for IK initial guess
|
||||
self._last_raw_obs: RobotObservation = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
@@ -97,117 +84,40 @@ class Robot(abc.ABC):
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
# ── Pipeline interface ────────────────────────────────────────────────────
|
||||
|
||||
def output_pipeline(self) -> RobotProcessorPipeline:
|
||||
"""
|
||||
Pipeline applied inside get_observation() to transform raw hardware observations.
|
||||
Default: identity (no-op). Override via set_output_pipeline() or subclassing.
|
||||
|
||||
Example: set a forward-kinematics pipeline to convert joint positions to EE pose.
|
||||
"""
|
||||
return self._output_pipeline
|
||||
|
||||
def input_pipeline(self) -> RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]:
|
||||
"""
|
||||
Pipeline applied inside send_action() to transform incoming actions before hardware write.
|
||||
Default: identity (no-op). Override via set_input_pipeline() or subclassing.
|
||||
|
||||
The pipeline receives a (action, last_raw_obs) tuple so IK solvers can use the
|
||||
current joint configuration as an initial guess.
|
||||
|
||||
Example: set an inverse-kinematics pipeline to convert EE commands to joint positions.
|
||||
"""
|
||||
return self._input_pipeline
|
||||
|
||||
def set_output_pipeline(self, pipeline: RobotProcessorPipeline) -> None:
|
||||
"""Set the observation output pipeline (applied in get_observation())."""
|
||||
self._output_pipeline = pipeline
|
||||
|
||||
def set_input_pipeline(self, pipeline: RobotProcessorPipeline) -> None:
|
||||
"""Set the action input pipeline (applied in send_action())."""
|
||||
self._input_pipeline = pipeline
|
||||
|
||||
# ── Feature properties ────────────────────────────────────────────────────
|
||||
|
||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_features(self) -> dict:
|
||||
"""
|
||||
Pipeline-transformed observation features.
|
||||
A dictionary describing the structure and types of the observations produced by the robot.
|
||||
Its structure (keys) should match the structure of what is returned by :pymeth:`get_observation`.
|
||||
Values for the dict should either be:
|
||||
- The type of the value if it's a simple value, e.g. `float` for single proprioceptive value (a joint's position/velocity)
|
||||
- A tuple representing the shape if it's an array-type value, e.g. `(height, width, channel)` for images
|
||||
|
||||
Applies output_pipeline().transform_features() to raw_observation_features so the
|
||||
returned dict matches what get_observation() actually returns to callers.
|
||||
|
||||
Use raw_observation_features to inspect hardware-level feature shapes.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
"""
|
||||
from lerobot.datasets.pipeline_features import create_initial_features # lazy import
|
||||
|
||||
initial = create_initial_features(observation=self.raw_observation_features)
|
||||
transformed = self.output_pipeline().transform_features(initial)
|
||||
return transformed.get(PipelineFeatureType.OBSERVATION, {})
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def raw_observation_features(self) -> dict:
|
||||
"""
|
||||
Hardware-level observation features (before any pipeline transformation).
|
||||
|
||||
A dictionary describing the structure and types of the observations produced
|
||||
directly by the robot hardware. Its structure (keys) should match the structure
|
||||
of what is returned by :pymeth:`_get_observation`. Values should be:
|
||||
- The type if it's a simple value, e.g. ``float`` for joint position
|
||||
- A tuple representing the shape for array values, e.g. ``(H, W, C)`` for images
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
Note: this property should be able to be called regardless of whether the robot is connected or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def raw_action_features(self) -> dict:
|
||||
"""
|
||||
Hardware-level action features (before any pipeline transformation).
|
||||
|
||||
A dictionary describing the structure and types of the actions accepted directly
|
||||
by the robot hardware (i.e. what :pymeth:`_send_action` receives). Its structure
|
||||
(keys) should match the structure of what is expected by :pymeth:`_send_action`.
|
||||
Values should be the type of the value if it's a simple value, e.g. ``float`` for
|
||||
single proprioceptive value (a joint's goal position/velocity).
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
"""
|
||||
Pipeline-transformed action features.
|
||||
A dictionary describing the structure and types of the actions expected by the robot. Its structure
|
||||
(keys) should match the structure of what is passed to :pymeth:`send_action`. Values for the dict
|
||||
should be the type of the value if it's a simple value, e.g. `float` for single proprioceptive value
|
||||
(a joint's goal position/velocity)
|
||||
|
||||
Applies input_pipeline().transform_features() to raw_action_features so the
|
||||
returned dict reflects what the input pipeline outputs to hardware.
|
||||
|
||||
Use raw_action_features to inspect hardware-level action feature shapes.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
Note: this property should be able to be called regardless of whether the robot is connected or not.
|
||||
"""
|
||||
from lerobot.datasets.pipeline_features import create_initial_features # lazy import
|
||||
|
||||
initial = create_initial_features(action=self.raw_action_features)
|
||||
transformed = self.input_pipeline().transform_features(initial)
|
||||
return transformed.get(PipelineFeatureType.ACTION, {})
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
Whether the robot is currently connected or not. If ``False``, calling
|
||||
:pymeth:`get_observation` or :pymeth:`send_action` should raise an error.
|
||||
Whether the robot is currently connected or not. If `False`, calling :pymeth:`get_observation` or
|
||||
:pymeth:`send_action` should raise an error.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -225,7 +135,7 @@ class Robot(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Whether the robot is currently calibrated or not. Should be always ``True`` if not applicable"""
|
||||
"""Whether the robot is currently calibrated or not. Should be always `True` if not applicable"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -243,7 +153,7 @@ class Robot(abc.ABC):
|
||||
Helper to load calibration data from the specified file.
|
||||
|
||||
Args:
|
||||
fpath (Path | None): Optional path to the calibration file. Defaults to ``self.calibration_fpath``.
|
||||
fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`.
|
||||
"""
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath) as f, draccus.config_type("json"):
|
||||
@@ -254,7 +164,7 @@ class Robot(abc.ABC):
|
||||
Helper to save calibration data to the specified file.
|
||||
|
||||
Args:
|
||||
fpath (Path | None): Optional path to save the calibration file. Defaults to ``self.calibration_fpath``.
|
||||
fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`.
|
||||
"""
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath, "w") as f, draccus.config_type("json"):
|
||||
@@ -268,64 +178,30 @@ class Robot(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# ── Template methods (concrete, call pipeline internally) ─────────────────
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Retrieve the current observation from the robot and apply the output pipeline.
|
||||
|
||||
Calls :pymeth:`_get_observation` to get raw hardware data, caches it for use as
|
||||
IK initial guess in :pymeth:`send_action`, then applies :pymeth:`output_pipeline`.
|
||||
Retrieve the current observation from the robot.
|
||||
|
||||
Returns:
|
||||
RobotObservation: Pipeline-transformed observation. With the default identity
|
||||
pipeline this equals the raw observation from :pymeth:`_get_observation`.
|
||||
RobotObservation: A flat dictionary representing the robot's current sensory state. Its structure
|
||||
should match :pymeth:`observation_features`.
|
||||
"""
|
||||
raw = self._get_observation()
|
||||
self._last_raw_obs = raw
|
||||
return self.output_pipeline()(raw)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Retrieve the raw observation directly from robot hardware.
|
||||
|
||||
Returns:
|
||||
RobotObservation: A flat dictionary representing the robot's current sensory
|
||||
state. Its structure should match :pymeth:`raw_observation_features`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""
|
||||
Apply the input pipeline and send the resulting action to robot hardware.
|
||||
|
||||
The input pipeline receives ``(action, last_raw_obs)`` so IK solvers can use the
|
||||
cached joint configuration as an initial guess. With the default identity pipeline,
|
||||
the action is forwarded unchanged.
|
||||
Send an action command to the robot.
|
||||
|
||||
Args:
|
||||
action (RobotAction): Dictionary representing the desired action. Its structure
|
||||
should match :pymeth:`action_features`.
|
||||
action (RobotAction): Dictionary representing the desired action. Its structure should match
|
||||
:pymeth:`action_features`.
|
||||
|
||||
Returns:
|
||||
RobotAction: The action actually sent to the motors, potentially clipped or
|
||||
modified by the pipeline or hardware safety limits.
|
||||
"""
|
||||
transformed = self.input_pipeline()((action, self._last_raw_obs))
|
||||
return self._send_action(transformed)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""
|
||||
Send an action command directly to robot hardware.
|
||||
|
||||
Args:
|
||||
action (RobotAction): Dictionary of motor-level commands. Its structure should
|
||||
match what the hardware expects (typically motor positions/velocities).
|
||||
|
||||
Returns:
|
||||
RobotAction: The action actually sent, potentially clipped by safety limits.
|
||||
RobotAction: The action actually sent to the motors potentially clipped or modified, e.g. by
|
||||
safety limits on velocity.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .ee_space import make_so10x_fk_observation_pipeline, make_so10x_ik_action_pipeline
|
||||
|
||||
__all__ = ["make_so10x_fk_observation_pipeline", "make_so10x_ik_action_pipeline"]
|
||||
@@ -1,147 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
End-effector space pipelines for SO-100/101 follower robots.
|
||||
|
||||
These factory functions return ready-to-use pipelines that convert between joint space
|
||||
and Cartesian end-effector space. Attach them to a robot with ``set_output_pipeline`` /
|
||||
``set_input_pipeline`` to enable EE-space recording and teleoperation.
|
||||
|
||||
Example::
|
||||
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
)
|
||||
|
||||
motor_names = list(follower.bus.motors.keys())
|
||||
follower.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
follower.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
"""
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
|
||||
_DEFAULT_EE_BOUNDS = {"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}
|
||||
_DEFAULT_GRIPPER_FRAME = "gripper_frame_link"
|
||||
|
||||
|
||||
def make_so10x_fk_observation_pipeline(
|
||||
urdf_path: str,
|
||||
motor_names: list[str],
|
||||
*,
|
||||
target_frame_name: str = _DEFAULT_GRIPPER_FRAME,
|
||||
) -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
|
||||
"""
|
||||
Create a forward-kinematics observation pipeline for SO-100/101 follower robots.
|
||||
|
||||
Converts raw joint positions (observation) into end-effector pose (position + orientation).
|
||||
Attach this to a follower robot via ``set_output_pipeline`` so that ``get_observation()``
|
||||
returns EE coordinates instead of raw joint angles.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the SO-100/101 URDF file used for kinematics.
|
||||
motor_names: Ordered list of motor names matching the URDF joint names.
|
||||
target_frame_name: Name of the end-effector frame in the URDF.
|
||||
|
||||
Returns:
|
||||
A RobotProcessorPipeline that maps joint observations to EE observations.
|
||||
|
||||
Example::
|
||||
|
||||
follower.set_output_pipeline(
|
||||
make_so10x_fk_observation_pipeline("./so101.urdf", motor_names)
|
||||
)
|
||||
obs = follower.get_observation() # now contains ee.x, ee.y, ee.z, ...
|
||||
"""
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path=urdf_path,
|
||||
target_frame_name=target_frame_name,
|
||||
joint_names=motor_names,
|
||||
)
|
||||
return RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
def make_so10x_ik_action_pipeline(
|
||||
urdf_path: str,
|
||||
motor_names: list[str],
|
||||
*,
|
||||
target_frame_name: str = _DEFAULT_GRIPPER_FRAME,
|
||||
end_effector_bounds: dict | None = None,
|
||||
max_ee_step_m: float = 0.10,
|
||||
) -> RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]:
|
||||
"""
|
||||
Create an inverse-kinematics action pipeline for SO-100/101 follower robots.
|
||||
|
||||
Converts incoming end-effector pose commands into joint positions, applying safety
|
||||
bounds and step-size limits before solving IK. The current joint positions are used
|
||||
as the IK initial guess (taken from the cached ``_last_raw_obs``).
|
||||
|
||||
Attach this to a follower robot via ``set_input_pipeline`` so that ``send_action()``
|
||||
receives EE commands and translates them to motor positions before the hardware write.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the SO-100/101 URDF file used for kinematics.
|
||||
motor_names: Ordered list of motor names matching the URDF joint names.
|
||||
target_frame_name: Name of the end-effector frame in the URDF.
|
||||
end_effector_bounds: Dict with ``"min"`` and ``"max"`` lists (3D position bounds in metres).
|
||||
Defaults to ``{"min": [-1, -1, -1], "max": [1, 1, 1]}``.
|
||||
max_ee_step_m: Maximum allowed EE position change per step in metres.
|
||||
|
||||
Returns:
|
||||
A RobotProcessorPipeline that maps (EE action, raw obs) to joint action.
|
||||
|
||||
Example::
|
||||
|
||||
follower.set_input_pipeline(
|
||||
make_so10x_ik_action_pipeline("./so101.urdf", motor_names)
|
||||
)
|
||||
# send_action() now accepts ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_vel
|
||||
"""
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path=urdf_path,
|
||||
target_frame_name=target_frame_name,
|
||||
joint_names=motor_names,
|
||||
)
|
||||
bounds = end_effector_bounds or _DEFAULT_EE_BOUNDS
|
||||
return RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
EEBoundsAndSafety(end_effector_bounds=bounds, max_ee_step_m=max_ee_step_m),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
@@ -74,11 +74,11 @@ class SOFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -176,7 +176,7 @@ class SOFollower(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -194,7 +194,7 @@ class SOFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -170,7 +170,7 @@ class UnitreeG1(Robot):
|
||||
time.sleep(sleep_time)
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
@@ -273,7 +273,7 @@ class UnitreeG1(Robot):
|
||||
for cam in self._cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return {}
|
||||
@@ -351,10 +351,10 @@ class UnitreeG1(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
for motor in G1_29_JointIndex:
|
||||
key = f"{motor.name}.q"
|
||||
if key in action:
|
||||
@@ -421,7 +421,7 @@ class UnitreeG1(Robot):
|
||||
num_steps = int(total_time / control_dt)
|
||||
|
||||
# get current state
|
||||
obs = self._get_observation()
|
||||
obs = self.get_observation()
|
||||
|
||||
# record current positions
|
||||
init_dof_pos = np.zeros(29, dtype=np.float32)
|
||||
@@ -439,7 +439,7 @@ class UnitreeG1(Robot):
|
||||
interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha
|
||||
action_dict[f"{motor.name}.q"] = float(interp_pos)
|
||||
|
||||
self._send_action(action_dict)
|
||||
self.send_action(action_dict)
|
||||
|
||||
# Maintain constant control rate
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
@@ -56,7 +56,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
|
||||
@@ -43,6 +43,17 @@ lerobot-eval \
|
||||
|
||||
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
|
||||
|
||||
You can also evaluate a model on a Hub environment with custom kwargs:
|
||||
```
|
||||
lerobot-eval \
|
||||
--policy.path=HF_USER/HF_REPO \
|
||||
--env=HF_USER/HF_REPO \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--env_kwargs.environment=env_A \
|
||||
--env_kwargs.embodiment=emb_B \
|
||||
```
|
||||
|
||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||
"""
|
||||
|
||||
@@ -521,6 +532,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
n_envs=cfg.eval.batch_size,
|
||||
use_async_envs=cfg.eval.use_async_envs,
|
||||
trust_remote_code=cfg.trust_remote_code,
|
||||
**cfg.env_kwargs,
|
||||
)
|
||||
|
||||
logging.info("Making policy.")
|
||||
|
||||
@@ -61,7 +61,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -74,8 +74,6 @@ from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -87,16 +85,19 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_processors,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
@@ -124,7 +125,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
@@ -139,11 +139,6 @@ from lerobot.utils.control_utils import (
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.pipeline_utils import (
|
||||
build_dataset_features,
|
||||
check_action_space_compatibility,
|
||||
check_observation_space_compatibility,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
@@ -230,9 +225,6 @@ class RecordConfig:
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
|
||||
# Only applies when using a policy (not teleop)
|
||||
interpolation_multiplier: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -256,23 +248,28 @@ class RecordConfig:
|
||||
""" --------------- record_loop() data flow --------------------------
|
||||
[ Robot ]
|
||||
V
|
||||
[ robot.get_observation() ] → applies output_pipeline internally → obs
|
||||
[ robot.get_observation() ] ---> raw_obs
|
||||
V
|
||||
[ robot_observation_processor ] ---> processed_obs
|
||||
V
|
||||
.-----( ACTION LOGIC )------------------.
|
||||
V V
|
||||
[ From Teleoperator ] [ From Policy ]
|
||||
| |
|
||||
| teleop.get_action() | predict_action(obs)
|
||||
| (output_pipeline applied internally) | |
|
||||
| | | V
|
||||
'----> action '---> policy_action_dict
|
||||
| [teleop.get_action] -> raw_action | [predict_action]
|
||||
| | | |
|
||||
| V | V
|
||||
| [teleop_action_processor] | |
|
||||
| | | |
|
||||
'---> processed_teleop_action '---> processed_policy_action
|
||||
| |
|
||||
'-------------------------.-------------'
|
||||
V
|
||||
[ robot.send_action(action) ]
|
||||
(input_pipeline applied internally)
|
||||
[ robot_action_processor ] --> robot_action_to_send
|
||||
V
|
||||
( Save action + obs to Dataset )
|
||||
[ robot.send_action() ] -- (Robot Executes)
|
||||
V
|
||||
( Save to Dataset )
|
||||
V
|
||||
( Rerun Log / Loop Wait )
|
||||
"""
|
||||
@@ -283,6 +280,15 @@ def record_loop(
|
||||
robot: Robot,
|
||||
events: dict,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
], # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
], # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline[
|
||||
RobotObservation, RobotObservation
|
||||
], # runs after robot
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
@@ -291,30 +297,8 @@ def record_loop(
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
interpolator: ActionInterpolator | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
"""
|
||||
Core recording loop. Robot and teleoperator pipelines are applied internally —
|
||||
no explicit processor arguments are needed.
|
||||
|
||||
Args:
|
||||
robot: The robot instance. Its output_pipeline() transforms observations and
|
||||
its input_pipeline() transforms actions before hardware write.
|
||||
events: Control events dict (exit_early, stop_recording, rerecord_episode).
|
||||
fps: Target control loop frequency.
|
||||
dataset: If provided, frames are written here each step.
|
||||
teleop: Teleoperator or list of teleoperators. Its output_pipeline() transforms
|
||||
actions (e.g., joint → EE) before they are sent to the robot.
|
||||
policy: Optional pre-trained policy for closed-loop control.
|
||||
preprocessor: Policy input pre-processor.
|
||||
postprocessor: Policy output post-processor.
|
||||
control_time_s: Episode duration in seconds.
|
||||
single_task: Task description string saved with each frame.
|
||||
display_data: If True, log observations and actions to Rerun.
|
||||
interpolator: Optional action interpolator for smoother policy control.
|
||||
display_compressed_images: If True, compress images before Rerun display.
|
||||
"""
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
@@ -349,17 +333,6 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset interpolator if provided
|
||||
if interpolator is not None:
|
||||
interpolator.reset()
|
||||
|
||||
# Calculate control interval based on interpolation
|
||||
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
|
||||
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
|
||||
# Pre-compute once — action features don't change during a recording episode
|
||||
action_keys = sorted(robot.action_features) if use_interpolation else []
|
||||
|
||||
no_action_count = 0
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -369,85 +342,65 @@ def record_loop(
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation (output_pipeline applied internally)
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
# With interpolation: only call policy when interpolator needs new action
|
||||
if use_interpolation:
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
# send_action applies input_pipeline (e.g. IK) internally;
|
||||
# capture the actually-sent joint action for interpolation
|
||||
sent_joint_action = robot.send_action(act_processed_policy)
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
# Build interpolation tensor from the motor-level joint action
|
||||
action_tensor = torch.tensor([sent_joint_action[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
# Get interpolated action (in joint/motor space)
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
action_values = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
# Interpolated values are already in joint space; bypass IK pipeline
|
||||
robot._send_action(action_values)
|
||||
else:
|
||||
# No action available yet, skip this iteration
|
||||
continue
|
||||
else:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
# send_action applies input_pipeline (e.g. IK) internally
|
||||
robot.send_action(act_processed_policy)
|
||||
action_values = act_processed_policy
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
# get_action applies output_pipeline (e.g. FK) internally
|
||||
action_values = teleop.get_action()
|
||||
# send_action applies input_pipeline (e.g. IK) internally
|
||||
robot.send_action(action_values)
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
# LeKiwi multi-teleop path
|
||||
arm_action = teleop_arm.get_action() # output_pipeline applied internally
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
keyboard_action = teleop_keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
action_values = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
robot.send_action(action_values) # input_pipeline applied internally
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
else:
|
||||
no_action_count += 1
|
||||
if no_action_count == 1 or no_action_count % 10 == 0:
|
||||
logging.warning(
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device. "
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
if policy is not None and act_processed_policy is not None:
|
||||
action_values = act_processed_policy
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
else:
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
@@ -456,12 +409,12 @@ def record_loop(
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs, action=action_values, compress_images=display_compressed_images
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = control_interval - dt_s
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
@@ -487,9 +440,22 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
|
||||
|
||||
# Dataset features derived automatically from robot/teleop pipelines.
|
||||
# When teleop is None (policy-only recording), only observation features are included.
|
||||
dataset_features = build_dataset_features(robot, teleop, use_videos=cfg.dataset.video)
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(
|
||||
action=robot.action_features
|
||||
), # TODO(steven, pepijn): in future this should be come from teleop or policy
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
@@ -535,7 +501,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
interpolator = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
@@ -546,19 +511,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
# Create interpolator for smoother policy control
|
||||
if cfg.interpolation_multiplier > 1:
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
teleop.connect()
|
||||
|
||||
if teleop is not None:
|
||||
check_action_space_compatibility(teleop, robot)
|
||||
check_observation_space_compatibility(robot, teleop)
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
@@ -574,6 +531,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -582,7 +542,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
interpolator=interpolator,
|
||||
display_compressed_images=display_compressed_images,
|
||||
)
|
||||
|
||||
@@ -601,6 +560,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
|
||||
@@ -47,6 +47,9 @@ from pprint import pformat
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import (
|
||||
make_default_robot_action_processor,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -77,7 +80,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
@@ -96,6 +99,8 @@ def replay(cfg: ReplayConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
|
||||
@@ -115,10 +120,11 @@ def replay(cfg: ReplayConfig):
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
# Update cached observation so the robot's input pipeline can use it (e.g. for IK)
|
||||
robot.get_observation()
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
_ = robot.send_action(action)
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
@@ -43,7 +43,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -52,7 +51,6 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
@@ -61,6 +61,12 @@ import rerun as rr
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_processors,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -88,13 +94,11 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.pipeline_utils import check_action_space_compatibility, check_observation_space_compatibility
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
@@ -122,28 +126,28 @@ def teleop_loop(
|
||||
teleop: Teleoperator,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
|
||||
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
|
||||
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
|
||||
display_data: bool = False,
|
||||
duration: float | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
"""
|
||||
Continuously reads actions from a teleoperation device, sends them to a robot,
|
||||
and optionally displays the robot's state. Pipelines are applied internally by
|
||||
the robot and teleoperator objects.
|
||||
|
||||
The loop runs at the specified frequency until a set duration is reached or it
|
||||
is manually interrupted.
|
||||
This function continuously reads actions from a teleoperation device, processes them through optional
|
||||
pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
|
||||
specified frequency until a set duration is reached or it is manually interrupted.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator device instance providing control actions.
|
||||
robot: The robot instance being controlled.
|
||||
fps: The target frequency for the control loop in frames per second.
|
||||
display_data: If True, fetches robot observations and displays them in the
|
||||
console and Rerun.
|
||||
display_compressed_images: If True, compresses images before sending them
|
||||
to Rerun for display.
|
||||
duration: The maximum duration of the teleoperation loop in seconds.
|
||||
If None, the loop runs indefinitely.
|
||||
display_data: If True, fetches robot observations and displays them in the console and Rerun.
|
||||
display_compressed_images: If True, compresses images before sending them to Rerun for display.
|
||||
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
|
||||
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
|
||||
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
||||
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
||||
"""
|
||||
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
@@ -152,29 +156,40 @@ def teleop_loop(
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get teleop action (output_pipeline applied internally)
|
||||
action = teleop.get_action()
|
||||
# Get robot observation
|
||||
# Not really needed for now other than for visualization
|
||||
# teleop_action_processor can take None as an observation
|
||||
# given that it is the identity processor as default
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Send action to robot (input_pipeline applied internally)
|
||||
robot_action_sent = robot.send_action(action)
|
||||
# Get teleop action
|
||||
raw_action = teleop.get_action()
|
||||
|
||||
# Process teleop action through pipeline
|
||||
teleop_action = teleop_action_processor((raw_action, obs))
|
||||
|
||||
# Process action for robot through pipeline
|
||||
robot_action_to_send = robot_action_processor((teleop_action, obs))
|
||||
|
||||
# Send processed action to robot (robot_action_processor.to_output should return RobotAction)
|
||||
_ = robot.send_action(robot_action_to_send)
|
||||
|
||||
if display_data:
|
||||
# Get robot observation (output_pipeline applied internally)
|
||||
obs = robot.get_observation()
|
||||
teleop.send_feedback(obs)
|
||||
# Process robot observation through pipeline
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
|
||||
log_rerun_data(
|
||||
observation=obs,
|
||||
action=action,
|
||||
observation=obs_transition,
|
||||
action=teleop_action,
|
||||
compress_images=display_compressed_images,
|
||||
)
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
for motor, value in robot_action_sent.items():
|
||||
if isinstance(value, float | int):
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_sent) + 3)
|
||||
# Display the final robot action that was sent
|
||||
for motor, value in robot_action_to_send.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_to_send) + 3)
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
@@ -200,13 +215,11 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
teleop.connect()
|
||||
robot.connect()
|
||||
|
||||
check_action_space_compatibility(teleop, robot)
|
||||
check_observation_space_compatibility(robot, teleop)
|
||||
|
||||
try:
|
||||
teleop_loop(
|
||||
teleop=teleop,
|
||||
@@ -214,6 +227,9 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
fps=cfg.fps,
|
||||
display_data=cfg.display_data,
|
||||
duration=cfg.teleop_time_s,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
display_compressed_images=display_compressed_images,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
|
||||
@@ -24,7 +24,6 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -52,7 +51,6 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -380,10 +378,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
# Use effective batch size for proper epoch calculation in distributed training
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size,
|
||||
effective_batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
@@ -392,14 +390,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -424,8 +414,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -519,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
@@ -72,9 +72,9 @@ class BiOpenArmLeader(Teleoperator):
|
||||
self.right_arm = OpenArmLeader(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.raw_action_features
|
||||
right_arm_features = self.right_arm.raw_action_features
|
||||
def action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.action_features
|
||||
right_arm_features = self.right_arm.action_features
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_features.items()},
|
||||
@@ -82,7 +82,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -112,7 +112,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -125,7 +125,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -55,9 +55,9 @@ class BiSOLeader(Teleoperator):
|
||||
self.right_arm = SOLeader(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.raw_action_features
|
||||
right_arm_features = self.right_arm.raw_action_features
|
||||
def action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.action_features
|
||||
right_arm_features = self.right_arm.action_features
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_features.items()},
|
||||
@@ -65,7 +65,7 @@ class BiSOLeader(Teleoperator):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -94,7 +94,7 @@ class BiSOLeader(Teleoperator):
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -107,7 +107,7 @@ class BiSOLeader(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class GamepadTeleop(Teleoperator):
|
||||
self.gamepad = None
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
def action_features(self) -> dict:
|
||||
if self.config.use_gripper:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
@@ -72,7 +72,7 @@ class GamepadTeleop(Teleoperator):
|
||||
}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict:
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
def connect(self) -> None:
|
||||
@@ -87,7 +87,7 @@ class GamepadTeleop(Teleoperator):
|
||||
self.gamepad.start()
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
# Update the controller to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
@@ -180,7 +180,7 @@ class GamepadTeleop(Teleoperator):
|
||||
# No additional configuration needed
|
||||
pass
|
||||
|
||||
def _send_feedback(self, feedback: dict) -> None:
|
||||
def send_feedback(self, feedback: dict) -> None:
|
||||
"""Send feedback to the gamepad."""
|
||||
# Gamepad doesn't support feedback
|
||||
pass
|
||||
|
||||
@@ -81,11 +81,11 @@ class HomunculusArm(Teleoperator):
|
||||
self.state_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
def action_features(self) -> dict:
|
||||
return {f"{joint}.pos": float for joint in self.joints}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict:
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -298,11 +298,11 @@ class HomunculusArm(Teleoperator):
|
||||
logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
joint_positions = self._read()
|
||||
return {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -107,11 +107,11 @@ class HomunculusGlove(Teleoperator):
|
||||
self.state_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
def action_features(self) -> dict:
|
||||
return {f"{joint}.pos": float for joint in self.joints}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict:
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -324,13 +324,13 @@ class HomunculusGlove(Teleoperator):
|
||||
logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
joint_positions = self._read()
|
||||
return homunculus_glove_to_hope_jr_hand(
|
||||
{f"{joint}.pos": pos for joint, pos in joint_positions.items()}
|
||||
)
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -67,7 +67,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
self.logs = {}
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
def action_features(self) -> dict:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (len(self.arm),),
|
||||
@@ -75,7 +75,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict:
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -122,7 +122,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
self._drain_pressed_keys()
|
||||
@@ -133,7 +133,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
|
||||
return dict.fromkeys(action, None)
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -157,7 +157,7 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
self.misc_keys_queue = Queue()
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
def action_features(self) -> dict:
|
||||
if self.config.use_gripper:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
@@ -172,7 +172,7 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
self._drain_pressed_keys()
|
||||
delta_x = 0.0
|
||||
delta_y = 0.0
|
||||
@@ -338,7 +338,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
self.current_angular_speed = config.angular_speed
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
def action_features(self) -> dict:
|
||||
"""Return action format for rover (linear and angular velocities)."""
|
||||
return {
|
||||
"linear.vel": float,
|
||||
@@ -361,7 +361,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
self.current_pressed.pop(key_char, None)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get the current action based on pressed keys.
|
||||
|
||||
|
||||
@@ -58,11 +58,11 @@ class KochLeader(Teleoperator):
|
||||
)
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -160,7 +160,7 @@ class KochLeader(Teleoperator):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -168,7 +168,7 @@ class KochLeader(Teleoperator):
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -57,11 +57,11 @@ class OmxLeader(Teleoperator):
|
||||
)
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -149,7 +149,7 @@ class OmxLeader(Teleoperator):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -157,7 +157,7 @@ class OmxLeader(Teleoperator):
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class OpenArmLeader(Teleoperator):
|
||||
)
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Features produced by this teleoperator."""
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
@@ -75,7 +75,7 @@ class OpenArmLeader(Teleoperator):
|
||||
return features
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
"""Feedback features (not implemented for OpenArms)."""
|
||||
return {}
|
||||
|
||||
@@ -183,7 +183,7 @@ class OpenArmLeader(Teleoperator):
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get current action from the leader arm.
|
||||
|
||||
@@ -209,7 +209,7 @@ class OpenArmLeader(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
|
||||
use_degrees: bool = True
|
||||
@@ -1,372 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted on the leader side.
|
||||
LEFT_MOTORS_TO_FLIP = {"joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"}
|
||||
RIGHT_MOTORS_TO_FLIP = {"joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"}
|
||||
# Leader(OpenArmMini) -> Follower(OpenArms) joint remap
|
||||
JOINT_REMAP_TO_OPENARMS = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
||||
# Follower(OpenArms) -> Leader(OpenArmMini) joint remap
|
||||
JOINT_REMAP_TO_MINI = {"joint_7": "joint_6", "joint_6": "joint_7"}
|
||||
OPENARMS_GRIPPER_MIN = -65.0
|
||||
OPENARMS_GRIPPER_MAX = 0.0
|
||||
MINI_GRIPPER_MIN = 0.0
|
||||
MINI_GRIPPER_MAX = 100.0
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
name = "openarm_mini"
|
||||
|
||||
def __init__(self, config: OpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _mini_gripper_to_openarms(value: float) -> float:
|
||||
"""Convert OpenArmMini gripper range [0, 100] to OpenArms gripper range [-65, 0]."""
|
||||
mapped = OPENARMS_GRIPPER_MAX + (
|
||||
(value - MINI_GRIPPER_MIN)
|
||||
* (OPENARMS_GRIPPER_MIN - OPENARMS_GRIPPER_MAX)
|
||||
/ (MINI_GRIPPER_MAX - MINI_GRIPPER_MIN)
|
||||
)
|
||||
return max(min(mapped, OPENARMS_GRIPPER_MAX), OPENARMS_GRIPPER_MIN)
|
||||
|
||||
@staticmethod
|
||||
def _openarms_gripper_to_mini(value: float) -> float:
|
||||
"""Convert OpenArms gripper range [-65, 0] to OpenArmMini gripper range [0, 100]."""
|
||||
clipped = max(min(value, OPENARMS_GRIPPER_MAX), OPENARMS_GRIPPER_MIN)
|
||||
return MINI_GRIPPER_MIN + (
|
||||
(OPENARMS_GRIPPER_MAX - clipped)
|
||||
* (MINI_GRIPPER_MAX - MINI_GRIPPER_MIN)
|
||||
/ (OPENARMS_GRIPPER_MAX - OPENARMS_GRIPPER_MIN)
|
||||
)
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
user_input = input(
|
||||
f"Press ENTER to use existing calibration for {self.id}, "
|
||||
f"or type 'c' and press ENTER to run new calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
range_min = int(closed_pos)
|
||||
range_max = int(open_pos)
|
||||
drive_mode = 0
|
||||
else:
|
||||
range_min = int(open_pos)
|
||||
range_max = int(closed_pos)
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_min,
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
target_motor = JOINT_REMAP_TO_OPENARMS.get(motor, motor)
|
||||
mapped_val = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
if target_motor == "gripper":
|
||||
mapped_val = self._mini_gripper_to_openarms(mapped_val)
|
||||
action[f"right_{target_motor}.pos"] = mapped_val
|
||||
for motor, val in left_positions.items():
|
||||
target_motor = JOINT_REMAP_TO_OPENARMS.get(motor, motor)
|
||||
mapped_val = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
if target_motor == "gripper":
|
||||
mapped_val = self._mini_gripper_to_openarms(mapped_val)
|
||||
action[f"left_{target_motor}.pos"] = mapped_val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def enable_torque(self) -> None:
|
||||
"""Enable torque on both arms for active motion commands."""
|
||||
self.bus_right.enable_torque()
|
||||
self.bus_left.enable_torque()
|
||||
|
||||
@check_if_not_connected
|
||||
def disable_torque(self) -> None:
|
||||
"""Disable torque on both arms for manual teleoperation."""
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_left.disable_torque()
|
||||
|
||||
@check_if_not_connected
|
||||
def write_goal_positions(self, action: dict[str, float]) -> None:
|
||||
"""Send normalized bilateral goal positions to the underlying Feetech buses."""
|
||||
right_goals: dict[str, float] = {}
|
||||
left_goals: dict[str, float] = {}
|
||||
|
||||
for key, value in action.items():
|
||||
if not key.endswith(".pos"):
|
||||
continue
|
||||
|
||||
if key.startswith("right_"):
|
||||
openarms_motor = key.removeprefix("right_").removesuffix(".pos")
|
||||
mini_motor = JOINT_REMAP_TO_MINI.get(openarms_motor, openarms_motor)
|
||||
mapped_val = self._openarms_gripper_to_mini(value) if openarms_motor == "gripper" else value
|
||||
right_goals[mini_motor] = -mapped_val if mini_motor in RIGHT_MOTORS_TO_FLIP else mapped_val
|
||||
elif key.startswith("left_"):
|
||||
openarms_motor = key.removeprefix("left_").removesuffix(".pos")
|
||||
mini_motor = JOINT_REMAP_TO_MINI.get(openarms_motor, openarms_motor)
|
||||
mapped_val = self._openarms_gripper_to_mini(value) if openarms_motor == "gripper" else value
|
||||
left_goals[mini_motor] = -mapped_val if mini_motor in LEFT_MOTORS_TO_FLIP else mapped_val
|
||||
|
||||
if right_goals:
|
||||
self.bus_right.sync_write("Goal_Position", right_goals)
|
||||
if left_goals:
|
||||
self.bus_left.sync_write("Goal_Position", left_goals)
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
"""Route feedback position commands through the same OpenArms/OpenArmMini mapping."""
|
||||
self.write_goal_positions(feedback)
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -47,7 +47,7 @@ class BasePhone:
|
||||
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
"phone.pos": np.ndarray, # shape (3,)
|
||||
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
|
||||
@@ -56,15 +56,15 @@ class BasePhone:
|
||||
}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
# No haptic or other feedback implemented yet
|
||||
return {}
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
# No additional configuration required for phone teleop
|
||||
pass
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# We could add haptic feedback (vibrations) here, but it's not implemented yet
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -163,7 +163,7 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
return True, pos, rot, pose
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict:
|
||||
def get_action(self) -> dict:
|
||||
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose or not self.is_calibrated:
|
||||
return {}
|
||||
@@ -314,7 +314,7 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
self._latest_message = message
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict:
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
@@ -395,21 +395,21 @@ class Phone(Teleoperator):
|
||||
return self._phone_impl.is_calibrated
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.raw_action_features
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.action_features
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.raw_feedback_features
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.feedback_features
|
||||
|
||||
def configure(self) -> None:
|
||||
return self._phone_impl.configure()
|
||||
|
||||
def _get_action(self) -> dict:
|
||||
return self._phone_impl._get_action()
|
||||
def get_action(self) -> dict:
|
||||
return self._phone_impl.get_action()
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
return self._phone_impl._send_feedback(feedback)
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
return self._phone_impl.send_feedback(feedback)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return self._phone_impl.disconnect()
|
||||
|
||||
@@ -104,7 +104,7 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
return joints
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
@@ -120,7 +120,7 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -146,7 +146,7 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
|
||||
joint_action: dict[str, float] = {}
|
||||
@@ -168,7 +168,7 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return {**joint_action, **vel_action}
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .ee_space import make_so10x_leader_fk_pipeline
|
||||
|
||||
__all__ = ["make_so10x_leader_fk_pipeline"]
|
||||
@@ -1,82 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Forward-kinematics pipeline for SO-100/101 leader (teleoperator) arm.
|
||||
|
||||
Converts raw leader joint positions into end-effector pose. Attach this to a leader
|
||||
via ``set_output_pipeline`` so that ``get_action()`` returns EE coordinates instead of
|
||||
raw joint angles.
|
||||
|
||||
Example::
|
||||
|
||||
from lerobot.teleoperators.so_leader.pipelines import make_so10x_leader_fk_pipeline
|
||||
|
||||
motor_names = list(leader.bus.motors.keys())
|
||||
leader.set_output_pipeline(make_so10x_leader_fk_pipeline(URDF_PATH, motor_names))
|
||||
action = leader.get_action() # now contains ee.x, ee.y, ee.z, ...
|
||||
"""
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import ForwardKinematicsJointsToEE
|
||||
|
||||
_DEFAULT_GRIPPER_FRAME = "gripper_frame_link"
|
||||
|
||||
|
||||
def make_so10x_leader_fk_pipeline(
|
||||
urdf_path: str,
|
||||
motor_names: list[str],
|
||||
*,
|
||||
target_frame_name: str = _DEFAULT_GRIPPER_FRAME,
|
||||
) -> RobotProcessorPipeline[RobotAction, RobotAction]:
|
||||
"""
|
||||
Create a forward-kinematics action pipeline for SO-100/101 leader teleoperators.
|
||||
|
||||
Converts raw leader joint positions (action) into end-effector pose (position +
|
||||
orientation + gripper). Attach this to a leader via ``set_output_pipeline`` so that
|
||||
``get_action()`` returns EE coordinates instead of raw joint angles.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the SO-100/101 URDF file used for kinematics.
|
||||
motor_names: Ordered list of motor names matching the URDF joint names.
|
||||
target_frame_name: Name of the end-effector frame in the URDF.
|
||||
|
||||
Returns:
|
||||
A RobotProcessorPipeline that maps joint actions to EE actions.
|
||||
|
||||
Example::
|
||||
|
||||
motor_names = list(leader.bus.motors.keys())
|
||||
leader.set_output_pipeline(
|
||||
make_so10x_leader_fk_pipeline("./so101.urdf", motor_names)
|
||||
)
|
||||
action = leader.get_action() # returns ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_vel
|
||||
"""
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path=urdf_path,
|
||||
target_frame_name=target_frame_name,
|
||||
joint_names=motor_names,
|
||||
)
|
||||
return RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics, motor_names=motor_names)],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
@@ -55,11 +55,11 @@ class SOLeader(Teleoperator):
|
||||
)
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -138,7 +138,7 @@ class SOLeader(Teleoperator):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -146,7 +146,7 @@ class SOLeader(Teleoperator):
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -12,23 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.motors.motors_bus import MotorCalibration
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.processor.core import RobotAction
|
||||
from lerobot.processor.pipeline import RobotProcessorPipeline
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
|
||||
|
||||
@@ -39,10 +33,6 @@ class Teleoperator(abc.ABC):
|
||||
This class provides a standardized interface for interacting with physical teleoperators.
|
||||
Subclasses must implement all abstract methods and properties to be usable.
|
||||
|
||||
Pipelines are first-class citizens: every teleoperator carries an optional output pipeline
|
||||
(applied in get_action()) and an optional input pipeline (applied in send_feedback()).
|
||||
Both default to identity (no-op), so existing teleoperators work without any changes.
|
||||
|
||||
Attributes:
|
||||
config_class (RobotConfig): The expected configuration class for this teleoperator.
|
||||
name (str): The unique name used to identify this teleoperator type.
|
||||
@@ -65,14 +55,6 @@ class Teleoperator(abc.ABC):
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
# Pipeline interface — default to identity (no-op), swap via set_output/input_pipeline()
|
||||
# Lazy import: factory is in lerobot.processor which loads after teleoperators at module init time,
|
||||
# but __init__ runs at instance-creation time when lerobot.processor is fully loaded.
|
||||
from lerobot.processor.factory import _make_identity_feedback_pipeline, _make_identity_teleop_action_pipeline
|
||||
|
||||
self._output_pipeline: RobotProcessorPipeline = _make_identity_teleop_action_pipeline()
|
||||
self._input_pipeline: RobotProcessorPipeline = _make_identity_feedback_pipeline()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
@@ -102,114 +84,38 @@ class Teleoperator(abc.ABC):
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
# ── Pipeline interface ────────────────────────────────────────────────────
|
||||
|
||||
def output_pipeline(self) -> RobotProcessorPipeline:
|
||||
"""
|
||||
Pipeline applied inside get_action() to transform raw hardware actions.
|
||||
Default: identity (no-op). Override via set_output_pipeline() or subclassing.
|
||||
|
||||
Example: set a forward-kinematics pipeline to convert leader joint positions to EE pose.
|
||||
"""
|
||||
return self._output_pipeline
|
||||
|
||||
def input_pipeline(self) -> RobotProcessorPipeline:
|
||||
"""
|
||||
Pipeline applied inside send_feedback() to transform incoming feedback.
|
||||
Default: identity (no-op). Override via set_input_pipeline() or subclassing.
|
||||
"""
|
||||
return self._input_pipeline
|
||||
|
||||
def set_output_pipeline(self, pipeline: RobotProcessorPipeline) -> None:
|
||||
"""Set the action output pipeline (applied in get_action())."""
|
||||
self._output_pipeline = pipeline
|
||||
|
||||
def set_input_pipeline(self, pipeline: RobotProcessorPipeline) -> None:
|
||||
"""Set the feedback input pipeline (applied in send_feedback())."""
|
||||
self._input_pipeline = pipeline
|
||||
|
||||
# ── Feature properties ────────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_features(self) -> dict:
|
||||
"""
|
||||
Pipeline-transformed action features.
|
||||
A dictionary describing the structure and types of the actions produced by the teleoperator. Its
|
||||
structure (keys) should match the structure of what is returned by :pymeth:`get_action`. Values for
|
||||
the dict should be the type of the value if it's a simple value, e.g. `float` for single
|
||||
proprioceptive value (a joint's goal position/velocity)
|
||||
|
||||
Applies output_pipeline().transform_features() to raw_action_features so the
|
||||
returned dict matches what get_action() actually produces for callers.
|
||||
|
||||
Use raw_action_features to inspect hardware-level feature shapes.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the
|
||||
teleoperator is connected or not.
|
||||
"""
|
||||
from lerobot.datasets.pipeline_features import create_initial_features # lazy import
|
||||
|
||||
initial = create_initial_features(action=self.raw_action_features)
|
||||
transformed = self.output_pipeline().transform_features(initial)
|
||||
return transformed.get(PipelineFeatureType.ACTION, {})
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def raw_action_features(self) -> dict:
|
||||
"""
|
||||
Hardware-level action features (before any pipeline transformation).
|
||||
|
||||
A dictionary describing the structure and types of the actions produced
|
||||
directly by the teleoperator hardware. Its structure (keys) should match
|
||||
the structure of what is returned by :pymeth:`_get_action`. Values should be
|
||||
the type of the value if it's a simple value, e.g. ``float`` for single
|
||||
proprioceptive value (a joint's goal position/velocity).
|
||||
|
||||
Note: this property should be able to be called regardless of whether the
|
||||
teleoperator is connected or not.
|
||||
Note: this property should be able to be called regardless of whether the robot is connected or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def raw_feedback_features(self) -> dict:
|
||||
"""
|
||||
Hardware-level feedback features (before any pipeline transformation).
|
||||
|
||||
A dictionary describing the structure and types of the feedback accepted directly
|
||||
by the teleoperator hardware (i.e. what :pymeth:`_send_feedback` receives). Its
|
||||
structure (keys) should match the structure of what is expected by
|
||||
:pymeth:`_send_feedback`. Values should be the type of the value if it's a simple
|
||||
value, e.g. ``float`` for single proprioceptive value.
|
||||
|
||||
Return an empty dict if this teleoperator does not support feedback.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the
|
||||
teleoperator is connected or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict:
|
||||
"""
|
||||
Pipeline-transformed feedback features.
|
||||
A dictionary describing the structure and types of the feedback actions expected by the robot. Its
|
||||
structure (keys) should match the structure of what is passed to :pymeth:`send_feedback`. Values for
|
||||
the dict should be the type of the value if it's a simple value, e.g. `float` for single
|
||||
proprioceptive value (a joint's goal position/velocity)
|
||||
|
||||
Applies input_pipeline().transform_features() to raw_feedback_features so the
|
||||
returned dict reflects what the input pipeline outputs to the teleoperator hardware.
|
||||
|
||||
Use raw_feedback_features to inspect hardware-level feedback feature shapes.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the
|
||||
teleoperator is connected or not.
|
||||
Note: this property should be able to be called regardless of whether the robot is connected or not.
|
||||
"""
|
||||
from lerobot.datasets.pipeline_features import create_initial_features # lazy import
|
||||
|
||||
initial = create_initial_features(observation=self.raw_feedback_features)
|
||||
transformed = self.input_pipeline().transform_features(initial)
|
||||
return transformed.get(PipelineFeatureType.OBSERVATION, {})
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
Whether the teleoperator is currently connected or not. If ``False``, calling
|
||||
:pymeth:`get_action` or :pymeth:`send_feedback` should raise an error.
|
||||
Whether the teleoperator is currently connected or not. If `False`, calling :pymeth:`get_action`
|
||||
or :pymeth:`send_feedback` should raise an error.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -227,7 +133,7 @@ class Teleoperator(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Whether the teleoperator is currently calibrated or not. Should be always ``True`` if not applicable"""
|
||||
"""Whether the teleoperator is currently calibrated or not. Should be always `True` if not applicable"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -245,7 +151,7 @@ class Teleoperator(abc.ABC):
|
||||
Helper to load calibration data from the specified file.
|
||||
|
||||
Args:
|
||||
fpath (Path | None): Optional path to the calibration file. Defaults to ``self.calibration_fpath``.
|
||||
fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`.
|
||||
"""
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath) as f, draccus.config_type("json"):
|
||||
@@ -256,7 +162,7 @@ class Teleoperator(abc.ABC):
|
||||
Helper to save calibration data to the specified file.
|
||||
|
||||
Args:
|
||||
fpath (Path | None): Optional path to save the calibration file. Defaults to ``self.calibration_fpath``.
|
||||
fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`.
|
||||
"""
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath, "w") as f, draccus.config_type("json"):
|
||||
@@ -270,51 +176,29 @@ class Teleoperator(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# ── Template methods (concrete, call pipeline internally) ─────────────────
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Retrieve the current action from the teleoperator and apply the output pipeline.
|
||||
|
||||
Calls :pymeth:`_get_action` to get raw hardware data, then applies
|
||||
:pymeth:`output_pipeline`.
|
||||
Retrieve the current action from the teleoperator.
|
||||
|
||||
Returns:
|
||||
RobotAction: Pipeline-transformed action. With the default identity pipeline
|
||||
this equals the raw action from :pymeth:`_get_action`.
|
||||
"""
|
||||
raw = self._get_action()
|
||||
return self.output_pipeline()(raw)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_action(self) -> RobotAction:
|
||||
"""
|
||||
Retrieve the raw action directly from teleoperator hardware.
|
||||
|
||||
Returns:
|
||||
RobotAction: A flat dictionary representing the teleoperator's current actions.
|
||||
Its structure should match :pymeth:`raw_action_features`.
|
||||
RobotAction: A flat dictionary representing the teleoperator's current actions. Its
|
||||
structure should match :pymeth:`observation_features`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
"""
|
||||
Apply the input pipeline and send the resulting feedback to teleoperator hardware.
|
||||
Send a feedback action command to the teleoperator.
|
||||
|
||||
Args:
|
||||
feedback (dict[str, Any]): Dictionary representing the desired feedback.
|
||||
Its structure should match :pymeth:`feedback_features`.
|
||||
"""
|
||||
transformed = self.input_pipeline()(feedback)
|
||||
self._send_feedback(transformed)
|
||||
feedback (dict[str, Any]): Dictionary representing the desired feedback. Its structure should match
|
||||
:pymeth:`feedback_features`.
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
"""
|
||||
Send feedback directly to teleoperator hardware.
|
||||
|
||||
Args:
|
||||
feedback (dict[str, Any]): Dictionary of hardware-level feedback commands.
|
||||
Returns:
|
||||
dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by
|
||||
safety limits on velocity.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -72,11 +72,11 @@ class UnitreeG1Teleoperator(Teleoperator):
|
||||
self.ik_helper: ExoskeletonIKHelper | None = None
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{name}.q": float for name in self._g1_joint_names}
|
||||
|
||||
@cached_property
|
||||
def raw_feedback_features(self) -> dict[str, type]:
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
@@ -114,12 +114,12 @@ class UnitreeG1Teleoperator(Teleoperator):
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> dict[str, float]:
|
||||
left_angles = self.left_arm.get_angles()
|
||||
right_angles = self.right_arm.get_angles()
|
||||
return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
||||
|
||||
def _send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Exoskeleton arms do not support feedback")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
|
||||
@@ -95,10 +95,6 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
elif config.type == "openarm_mini":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
|
||||
@@ -104,10 +104,9 @@ class MetricsTracker:
|
||||
self.metrics = metrics
|
||||
|
||||
self.steps = initial_step
|
||||
world_size = accelerator.num_processes if accelerator else 1
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
|
||||
self.samples = self.steps * self._batch_size * world_size
|
||||
self.samples = self.steps * self._batch_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
self.accelerator = accelerator
|
||||
@@ -133,8 +132,7 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
world_size = self.accelerator.num_processes if self.accelerator else 1
|
||||
self.samples += self._batch_size * world_size
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utilities for building dataset features from robot/teleoperator pipelines and for
|
||||
checking action/observation space compatibility between teleops and robots.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
|
||||
from lerobot.datasets.utils import combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
# Prefixes stripped from feature keys to produce clean dataset names.
|
||||
# Handles both fully-qualified (e.g. "observation.state.ee.x") and short (e.g. "state.ee.x") forms.
|
||||
_PREFIXES_TO_STRIP = tuple(
|
||||
f"{token}."
|
||||
for const in (ACTION, OBS_STATE, OBS_IMAGES)
|
||||
for token in (const, const.split(".")[-1])
|
||||
)
|
||||
|
||||
_IMAGES_TOKEN = OBS_IMAGES.split(".")[-1]
|
||||
|
||||
|
||||
def _should_keep(key: str, patterns: Sequence[str] | None) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
|
||||
def _strip_prefix(key: str) -> str:
|
||||
for prefix in _PREFIXES_TO_STRIP:
|
||||
if key.startswith(prefix):
|
||||
return key[len(prefix) :]
|
||||
return key
|
||||
|
||||
|
||||
def _features_to_dataset_spec(
|
||||
features: dict,
|
||||
*,
|
||||
is_action: bool,
|
||||
use_videos: bool,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a flat feature dict (as returned by ``robot.observation_features`` or
|
||||
``teleop.action_features``) into a LeRobot dataset feature specification.
|
||||
|
||||
Args:
|
||||
features: Flat dict mapping feature key → type or shape.
|
||||
is_action: True when ``features`` describes actions; False for observations.
|
||||
use_videos: When False, image observation features are excluded entirely.
|
||||
patterns: Optional regex patterns to filter state/action features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
Returns:
|
||||
A dict suitable for passing to ``LeRobotDataset.create(..., features=...)``.
|
||||
"""
|
||||
categorized: dict = {}
|
||||
for key, value in features.items():
|
||||
is_image = not is_action and (
|
||||
(isinstance(value, tuple) and len(value) == 3)
|
||||
or key.startswith(f"{OBS_IMAGES}.")
|
||||
or key.startswith(f"{_IMAGES_TOKEN}.")
|
||||
or f".{_IMAGES_TOKEN}." in key
|
||||
)
|
||||
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not _should_keep(key, patterns):
|
||||
continue
|
||||
|
||||
categorized[_strip_prefix(key)] = value
|
||||
|
||||
if not categorized:
|
||||
return {}
|
||||
|
||||
prefix = ACTION if is_action else OBS_STR
|
||||
return hw_to_dataset_features(categorized, prefix, use_videos)
|
||||
|
||||
|
||||
def build_dataset_features(
|
||||
robot,
|
||||
teleop=None,
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
action_features: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Derive dataset feature specifications from robot and teleoperator pipelines.
|
||||
|
||||
Reads ``robot.observation_features`` (which already reflects the robot's output
|
||||
pipeline transformation) and, when provided, ``teleop.action_features`` or an
|
||||
explicit ``action_features`` dict to determine what the dataset will store.
|
||||
|
||||
This replaces the old pattern of manually calling ``aggregate_pipeline_dataset_features``
|
||||
with explicit processor objects.
|
||||
|
||||
Args:
|
||||
robot: The robot instance (must have ``observation_features``).
|
||||
teleop: The teleoperator instance. When ``None`` and ``action_features`` is also
|
||||
``None`` (policy-only recording), only observation features are returned.
|
||||
use_videos: If True, image observations are included as video features.
|
||||
action_features: Explicit action feature dict, used when no teleop is available
|
||||
(e.g. evaluate/inference mode) but the dataset must match a specific action
|
||||
space (e.g. EE coordinates from a previously recorded dataset).
|
||||
|
||||
Returns:
|
||||
A combined feature dict suitable for passing to ``LeRobotDataset.create(..., features=...)``.
|
||||
|
||||
Example::
|
||||
|
||||
# Teleop recording
|
||||
features = build_dataset_features(follower, leader, use_videos=True)
|
||||
|
||||
# Policy-only recording (no teleop)
|
||||
features = build_dataset_features(robot, use_videos=True)
|
||||
|
||||
# Evaluate with explicit EE action space
|
||||
features = build_dataset_features(
|
||||
robot,
|
||||
use_videos=True,
|
||||
action_features={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
},
|
||||
)
|
||||
"""
|
||||
obs_ds = _features_to_dataset_spec(robot.observation_features, is_action=False, use_videos=use_videos)
|
||||
|
||||
if action_features is not None:
|
||||
act_ds = _features_to_dataset_spec(action_features, is_action=True, use_videos=False)
|
||||
elif teleop is not None:
|
||||
act_ds = _features_to_dataset_spec(teleop.action_features, is_action=True, use_videos=False)
|
||||
else:
|
||||
return obs_ds
|
||||
|
||||
return combine_feature_dicts(act_ds, obs_ds)
|
||||
|
||||
|
||||
def check_action_space_compatibility(teleop, robot) -> None:
|
||||
"""
|
||||
Warn if the teleoperator's pipeline-transformed action features don't match the robot's
|
||||
declared ``action_features``.
|
||||
|
||||
This is a soft check — a mismatch produces a warning but does not raise. It is intended
|
||||
to catch obvious misconfigurations (e.g., sending EE actions to a robot expecting joints)
|
||||
before the control loop starts.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator whose ``action_features`` describe what it sends.
|
||||
robot: The robot whose ``action_features`` describe what it expects.
|
||||
"""
|
||||
teleop_out = set(teleop.action_features.keys())
|
||||
robot_in = set(robot.action_features.keys())
|
||||
if teleop_out != robot_in:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Action space mismatch between teleop and robot.\n"
|
||||
f" Teleop sends: {sorted(teleop_out)}\n"
|
||||
f" Robot expects: {sorted(robot_in)}\n"
|
||||
"Ensure pipelines map between these spaces correctly.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
else:
|
||||
logging.debug("Action space compatibility check passed.")
|
||||
|
||||
|
||||
def check_observation_space_compatibility(robot, teleop) -> None:
|
||||
"""
|
||||
Warn if the robot's observation features don't cover what the teleoperator's
|
||||
``feedback_features`` expects.
|
||||
|
||||
A non-empty ``feedback_features`` that is not a subset of the robot's observation keys
|
||||
will produce a warning.
|
||||
|
||||
Args:
|
||||
robot: The robot whose ``observation_features`` describe what it produces.
|
||||
teleop: The teleoperator whose ``feedback_features`` describe what it expects as feedback.
|
||||
"""
|
||||
robot_obs = set(robot.observation_features.keys())
|
||||
teleop_feedback = set(teleop.feedback_features.keys())
|
||||
if teleop_feedback and not teleop_feedback.issubset(robot_obs):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Observation/feedback space mismatch.\n"
|
||||
f" Robot obs: {sorted(robot_obs)}\n"
|
||||
f" Teleop feedback expects: {sorted(teleop_feedback)}\n"
|
||||
"Ensure the robot observation pipeline covers all feedback keys.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
else:
|
||||
logging.debug("Observation/feedback space compatibility check passed.")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -266,3 +266,65 @@ def test_make_env_from_hub_async():
|
||||
|
||||
# clean up
|
||||
env.close()
|
||||
|
||||
|
||||
def test_make_env_from_hub_with_kwargs():
|
||||
"""Test that kwargs are correctly passed to hub environment's make_env."""
|
||||
hub_id = "lerobot/dummy-hub-env"
|
||||
|
||||
# Test with config_path kwarg
|
||||
envs_dict = make_env(
|
||||
hub_id,
|
||||
n_envs=1,
|
||||
trust_remote_code=True,
|
||||
config_path="/path/to/config.yaml",
|
||||
)
|
||||
env = envs_dict["cartpole_suite"][0]
|
||||
|
||||
assert hasattr(env, "hub_config")
|
||||
assert env.hub_config["config_path"] == "/path/to/config.yaml"
|
||||
env.close()
|
||||
|
||||
# Test with config_overrides dict
|
||||
envs_dict = make_env(
|
||||
hub_id,
|
||||
n_envs=1,
|
||||
trust_remote_code=True,
|
||||
config_overrides={"scene.object": "microwave", "sim.dt": 0.01},
|
||||
)
|
||||
env = envs_dict["cartpole_suite"][0]
|
||||
|
||||
assert env.hub_config["config_overrides"]["scene.object"] == "microwave"
|
||||
assert env.hub_config["config_overrides"]["sim.dt"] == 0.01
|
||||
env.close()
|
||||
|
||||
# Test with arbitrary extra kwargs
|
||||
envs_dict = make_env(
|
||||
hub_id,
|
||||
n_envs=1,
|
||||
trust_remote_code=True,
|
||||
custom_param="value",
|
||||
another_param=42,
|
||||
)
|
||||
env = envs_dict["cartpole_suite"][0]
|
||||
|
||||
assert env.hub_config["extra_kwargs"]["custom_param"] == "value"
|
||||
assert env.hub_config["extra_kwargs"]["another_param"] == 42
|
||||
env.close()
|
||||
|
||||
# Test combining config_path, config_overrides, and extra kwargs
|
||||
envs_dict = make_env(
|
||||
hub_id,
|
||||
n_envs=2,
|
||||
trust_remote_code=True,
|
||||
config_path="my_config.yaml",
|
||||
config_overrides={"robot": "gr1"},
|
||||
task_name="pick_and_place",
|
||||
)
|
||||
env = envs_dict["cartpole_suite"][0]
|
||||
|
||||
assert env.hub_config["config_path"] == "my_config.yaml"
|
||||
assert env.hub_config["config_overrides"]["robot"] == "gr1"
|
||||
assert env.hub_config["extra_kwargs"]["task_name"] == "pick_and_place"
|
||||
assert env.num_envs == 2
|
||||
env.close()
|
||||
|
||||
@@ -87,7 +87,7 @@ class MockRobot(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
@@ -116,7 +116,7 @@ class MockRobot(Robot):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
@@ -125,7 +125,7 @@ class MockRobot(Robot):
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -57,7 +57,7 @@ class MockTeleop(Teleoperator):
|
||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{motor}.pos": float for motor in self.motors}
|
||||
|
||||
@cached_property
|
||||
@@ -86,7 +86,7 @@ class MockTeleop(Teleoperator):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_action(self) -> RobotAction:
|
||||
def get_action(self) -> RobotAction:
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
@@ -95,7 +95,7 @@ class MockTeleop(Teleoperator):
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_feedback(self, feedback: dict[str, Any]) -> None: ...
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None: ...
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.pipeline_utils import _features_to_dataset_spec
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
EnvTransition,
|
||||
@@ -2040,68 +2040,102 @@ def test_features_remove_from_initial(policy_feature_factory):
|
||||
)
|
||||
|
||||
|
||||
# ── Tests for _features_to_dataset_spec ──────────────────────────────────────────────────────────
|
||||
# These replace the old aggregate_pipeline_dataset_features tests, covering the same categorisation
|
||||
# / filtering / prefix-stripping / HF-format logic via the private helper directly.
|
||||
@dataclass
|
||||
class AddActionEEAndJointFeatures(ProcessorStep):
|
||||
"""Adds both EE and JOINT action features."""
|
||||
|
||||
def __call__(self, tr):
|
||||
return tr
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# EE features
|
||||
features[PipelineFeatureType.ACTION]["action.ee.x"] = float
|
||||
features[PipelineFeatureType.ACTION]["action.ee.y"] = float
|
||||
# JOINT features
|
||||
features[PipelineFeatureType.ACTION]["action.j1.pos"] = float
|
||||
features[PipelineFeatureType.ACTION]["action.j2.pos"] = float
|
||||
return features
|
||||
|
||||
|
||||
def test_dataset_spec_action_with_patterns():
|
||||
"""Action features are filtered by pattern; unmatched keys are excluded."""
|
||||
features = {
|
||||
"action.ee.x": float,
|
||||
"action.ee.y": float,
|
||||
"action.j1.pos": float,
|
||||
"action.j2.pos": float,
|
||||
}
|
||||
out = _features_to_dataset_spec(
|
||||
features, is_action=True, use_videos=True, patterns=["action.j1.pos", "action.j2.pos"]
|
||||
@dataclass
|
||||
class AddObservationStateFeatures(ProcessorStep):
|
||||
"""Adds state features (and optionally an image spec to test precedence)."""
|
||||
|
||||
add_front_image: bool = False
|
||||
front_image_shape: tuple = (240, 320, 3)
|
||||
|
||||
def __call__(self, tr):
|
||||
return tr
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# State features (mix EE and a joint state)
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float
|
||||
if self.add_front_image:
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape
|
||||
return features
|
||||
|
||||
|
||||
def test_aggregate_joint_action_only():
|
||||
rp = DataProcessorPipeline([AddActionEEAndJointFeatures()])
|
||||
initial = {PipelineFeatureType.OBSERVATION: {"front": (480, 640, 3)}, PipelineFeatureType.ACTION: {}}
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features=initial,
|
||||
use_videos=True,
|
||||
patterns=["action.j1.pos", "action.j2.pos"],
|
||||
)
|
||||
|
||||
assert ACTION in out
|
||||
# Expect only ACTION with joint names
|
||||
assert ACTION in out and OBS_STATE not in out
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
assert set(out[ACTION]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out[ACTION]["shape"] == (len(out[ACTION]["names"]),)
|
||||
assert OBS_STATE not in out
|
||||
|
||||
|
||||
def test_dataset_spec_action_and_observation_with_videos():
|
||||
"""EE action + state obs + image obs; all appear with correct dtypes."""
|
||||
action_features = {"action.ee.x": float, "action.ee.y": float}
|
||||
obs_features = {
|
||||
f"{OBS_STATE}.ee.x": float,
|
||||
f"{OBS_STATE}.j1.pos": float,
|
||||
"front": (480, 640, 3),
|
||||
"side": (720, 1280, 3),
|
||||
}
|
||||
def test_aggregate_ee_action_and_observation_with_videos():
|
||||
rp = DataProcessorPipeline([AddActionEEAndJointFeatures(), AddObservationStateFeatures()])
|
||||
initial = {"front": (480, 640, 3), "side": (720, 1280, 3)}
|
||||
|
||||
act_out = _features_to_dataset_spec(action_features, is_action=True, use_videos=False)
|
||||
obs_out = _features_to_dataset_spec(obs_features, is_action=False, use_videos=True)
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", OBS_STATE],
|
||||
)
|
||||
|
||||
assert ACTION in act_out
|
||||
assert set(act_out[ACTION]["names"]) == {"ee.x", "ee.y"}
|
||||
assert act_out[ACTION]["dtype"] == "float32"
|
||||
# Action should pack only EE names
|
||||
assert ACTION in out
|
||||
assert set(out[ACTION]["names"]) == {"ee.x", "ee.y"}
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
|
||||
assert OBS_STATE in obs_out
|
||||
assert set(obs_out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"}
|
||||
assert obs_out[OBS_STATE]["dtype"] == "float32"
|
||||
# Observation state should pack both ee.x and j1.pos as a vector
|
||||
assert OBS_STATE in out
|
||||
assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"}
|
||||
assert out[OBS_STATE]["dtype"] == "float32"
|
||||
|
||||
for cam, shape in [("front", (480, 640, 3)), ("side", (720, 1280, 3))]:
|
||||
# Cameras from initial_features appear as videos
|
||||
for cam in ("front", "side"):
|
||||
key = f"{OBS_IMAGES}.{cam}"
|
||||
assert key in obs_out, f"missing camera key {key}"
|
||||
assert obs_out[key]["dtype"] == "video"
|
||||
assert obs_out[key]["shape"] == shape
|
||||
assert obs_out[key]["names"] == ["height", "width", "channels"]
|
||||
assert key in out
|
||||
assert out[key]["dtype"] == "video"
|
||||
assert out[key]["shape"] == initial[cam]
|
||||
assert out[key]["names"] == ["height", "width", "channels"]
|
||||
|
||||
|
||||
def test_dataset_spec_all_action_types():
|
||||
"""EE and joint action features are both included when no pattern filter."""
|
||||
features = {
|
||||
"action.ee.x": float,
|
||||
"action.ee.y": float,
|
||||
"action.j1.pos": float,
|
||||
"action.j2.pos": float,
|
||||
}
|
||||
out = _features_to_dataset_spec(features, is_action=True, use_videos=True, patterns=None)
|
||||
def test_aggregate_both_action_types():
|
||||
rp = DataProcessorPipeline([AddActionEEAndJointFeatures()])
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.j1", "action.j2.pos"],
|
||||
)
|
||||
|
||||
assert ACTION in out
|
||||
expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"}
|
||||
@@ -2109,40 +2143,58 @@ def test_dataset_spec_all_action_types():
|
||||
assert out[ACTION]["shape"] == (len(expected),)
|
||||
|
||||
|
||||
def test_dataset_spec_images_excluded_when_no_videos():
|
||||
"""Image observation features are dropped entirely when use_videos=False."""
|
||||
obs_features = {
|
||||
f"{OBS_STATE}.j1.pos": float,
|
||||
"back": (480, 640, 3),
|
||||
f"{OBS_IMAGES}.front": (240, 320, 3),
|
||||
}
|
||||
out = _features_to_dataset_spec(obs_features, is_action=False, use_videos=False)
|
||||
def test_aggregate_images_when_use_videos_false():
|
||||
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
|
||||
initial = {"back": (480, 640, 3)}
|
||||
|
||||
assert f"{OBS_IMAGES}.back" not in out
|
||||
assert f"{OBS_IMAGES}.front" not in out
|
||||
# Non-image state feature is still present
|
||||
assert OBS_STATE in out
|
||||
assert "j1.pos" in out[OBS_STATE]["names"]
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=False, # expect "image" dtype
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = f"{OBS_IMAGES}.back"
|
||||
key_front = f"{OBS_IMAGES}.front"
|
||||
assert key not in out
|
||||
assert key_front not in out
|
||||
|
||||
|
||||
def test_dataset_spec_images_included_when_use_videos():
|
||||
"""Image features appear as video entries when use_videos=True."""
|
||||
obs_features = {
|
||||
"back": (480, 640, 3),
|
||||
f"{OBS_IMAGES}.front": (240, 320, 3),
|
||||
}
|
||||
out = _features_to_dataset_spec(obs_features, is_action=False, use_videos=True)
|
||||
def test_aggregate_images_when_use_videos_true():
|
||||
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
|
||||
initial = {"back": (480, 640, 3)}
|
||||
|
||||
assert f"{OBS_IMAGES}.back" in out
|
||||
assert out[f"{OBS_IMAGES}.back"]["dtype"] == "video"
|
||||
assert out[f"{OBS_IMAGES}.back"]["shape"] == (480, 640, 3)
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
|
||||
use_videos=True,
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
assert f"{OBS_IMAGES}.front" in out
|
||||
assert out[f"{OBS_IMAGES}.front"]["dtype"] == "video"
|
||||
assert out[f"{OBS_IMAGES}.front"]["shape"] == (240, 320, 3)
|
||||
key = f"{OBS_IMAGES}.front"
|
||||
key_back = f"{OBS_IMAGES}.back"
|
||||
assert key in out
|
||||
assert key_back in out
|
||||
assert out[key]["dtype"] == "video"
|
||||
assert out[key_back]["dtype"] == "video"
|
||||
assert out[key_back]["shape"] == initial["back"]
|
||||
|
||||
|
||||
def test_dataset_spec_empty_features_returns_empty():
|
||||
"""Empty feature dict returns an empty output dict."""
|
||||
assert _features_to_dataset_spec({}, is_action=True, use_videos=True) == {}
|
||||
assert _features_to_dataset_spec({}, is_action=False, use_videos=True) == {}
|
||||
def test_initial_camera_not_overridden_by_step_image():
|
||||
# Step explicitly sets a different front image shape; initial has another shape.
|
||||
# aggregate_pipeline_dataset_features should keep the step's value (setdefault behavior on initial cams).
|
||||
rp = DataProcessorPipeline(
|
||||
[AddObservationStateFeatures(add_front_image=True, front_image_shape=(240, 320, 3))]
|
||||
)
|
||||
initial = {"front": (480, 640, 3)} # should NOT override the step-provided (240, 320, 3)
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=True,
|
||||
patterns=[f"{OBS_IMAGES}.front"],
|
||||
)
|
||||
|
||||
key = f"{OBS_IMAGES}.front"
|
||||
assert key in out
|
||||
assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Integration tests for loading robot/teleop pipelines from the Hugging Face Hub.
|
||||
|
||||
These tests require network access and are marked with ``@pytest.mark.integration``.
|
||||
Run with::
|
||||
|
||||
pytest tests/test_pipeline_hub.py -m integration -v
|
||||
|
||||
The tests verify the full end-to-end flow of:
|
||||
1. Loading a pipeline from the Hub via ``RobotProcessorPipeline.from_pretrained(...)``
|
||||
2. Attaching it to a robot or teleoperator via ``set_output_pipeline`` / ``set_input_pipeline``
|
||||
3. Verifying that ``observation_features`` / ``action_features`` differ from the raw versions
|
||||
|
||||
Note: The Hub repos referenced below are placeholders. Update them once actual pipelines
|
||||
are published to the Hub.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ─── Shared mock infrastructure (mirrors test_robot_pipeline.py) ──────────────
|
||||
|
||||
try:
|
||||
from tests.test_robot_pipeline import MockRobot, MockTeleop # type: ignore[import]
|
||||
except ImportError:
|
||||
# Fallback if tests are run from a different working directory
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_robot_pipeline import MockRobot, MockTeleop
|
||||
|
||||
|
||||
# ─── Integration tests ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_load_robot_pipeline_from_hub(tmp_path):
|
||||
"""
|
||||
Full end-to-end: load a FK observation pipeline for SO-101 from the Hub,
|
||||
attach it to a robot, and verify that observation_features are transformed.
|
||||
|
||||
Prerequisites:
|
||||
- A pipeline must be published at ``lerobot/so101-fk-observation-pipeline`` on the Hub.
|
||||
- A URDF file must be available locally (update ``local_urdf_path`` to point to it).
|
||||
"""
|
||||
pytest.importorskip("huggingface_hub")
|
||||
from lerobot.processor.pipeline import RobotProcessorPipeline
|
||||
|
||||
local_urdf_path = tmp_path / "so101.urdf"
|
||||
# NOTE: In a real test environment, provide an actual URDF or mock the kinematics.
|
||||
# For now, this test validates the Hub loading mechanism only if a URDF is provided.
|
||||
if not local_urdf_path.exists():
|
||||
pytest.skip("URDF not available; skipping Hub loading test")
|
||||
|
||||
pipeline = RobotProcessorPipeline.from_pretrained(
|
||||
"lerobot/so101-fk-observation-pipeline",
|
||||
overrides={"step_0": {"urdf_path": str(local_urdf_path)}},
|
||||
)
|
||||
robot = MockRobot()
|
||||
robot.set_output_pipeline(pipeline)
|
||||
|
||||
# Pipeline-transformed features should differ from raw features (EE vs joints)
|
||||
assert robot.observation_features != robot.raw_observation_features
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_load_teleop_pipeline_from_hub(tmp_path):
|
||||
"""
|
||||
Full end-to-end: load a FK action pipeline for SO-101 leader from the Hub,
|
||||
attach it to a teleoperator, and verify that action_features are transformed.
|
||||
|
||||
Prerequisites:
|
||||
- A pipeline must be published at ``lerobot/so101-leader-fk-action-pipeline`` on the Hub.
|
||||
- A URDF file must be available locally (update ``local_urdf_path`` to point to it).
|
||||
"""
|
||||
pytest.importorskip("huggingface_hub")
|
||||
from lerobot.processor.pipeline import RobotProcessorPipeline
|
||||
|
||||
local_urdf_path = tmp_path / "so101.urdf"
|
||||
if not local_urdf_path.exists():
|
||||
pytest.skip("URDF not available; skipping Hub loading test")
|
||||
|
||||
pipeline = RobotProcessorPipeline.from_pretrained(
|
||||
"lerobot/so101-leader-fk-action-pipeline",
|
||||
overrides={"step_0": {"urdf_path": str(local_urdf_path)}},
|
||||
)
|
||||
teleop = MockTeleop()
|
||||
teleop.set_output_pipeline(pipeline)
|
||||
|
||||
# Pipeline-transformed features should differ from raw features (EE vs joints)
|
||||
assert teleop.action_features != teleop.raw_action_features
|
||||
@@ -1,433 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unit tests for the robot/teleoperator pipeline interface.
|
||||
|
||||
Tests cover:
|
||||
- Default (identity) pipeline behaviour
|
||||
- Custom pipeline attachment via set_output_pipeline / set_input_pipeline
|
||||
- Auto-derived observation_features / action_features via pipelines
|
||||
- Compatibility checks
|
||||
- build_dataset_features utility
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.factory import (
|
||||
_make_identity_feedback_pipeline,
|
||||
_make_identity_observation_pipeline,
|
||||
_make_identity_robot_action_pipeline,
|
||||
_make_identity_teleop_action_pipeline,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
IdentityProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
RobotActionProcessorStep,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.utils.pipeline_utils import (
|
||||
build_dataset_features,
|
||||
check_action_space_compatibility,
|
||||
check_observation_space_compatibility,
|
||||
)
|
||||
|
||||
|
||||
# ─── Mock hardware classes ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockRobotConfig:
|
||||
id: str = "mock_robot"
|
||||
calibration_dir: Path | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockTeleopConfig:
|
||||
id: str = "mock_teleop"
|
||||
calibration_dir: Path | None = None
|
||||
|
||||
|
||||
_JOINT_NAMES = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
||||
_JOINT_FEATURES = {f"{j}.pos": float for j in _JOINT_NAMES}
|
||||
_EE_FEATURES = {"ee.x": float, "ee.y": float, "ee.z": float, "ee.wx": float, "ee.wy": float, "ee.wz": float, "ee.gripper_vel": float}
|
||||
|
||||
|
||||
class MockRobot(Robot):
|
||||
"""Minimal Robot that stores last action for assertion."""
|
||||
|
||||
config_class = MockRobotConfig
|
||||
name = "mock_robot"
|
||||
|
||||
def __init__(self):
|
||||
# bypass filesystem calibration setup; initialize with identity pipelines directly
|
||||
self._output_pipeline = _make_identity_observation_pipeline()
|
||||
self._input_pipeline = _make_identity_robot_action_pipeline()
|
||||
self._last_raw_obs: RobotObservation = {}
|
||||
self._last_sent: RobotAction = {}
|
||||
|
||||
@property
|
||||
def raw_observation_features(self) -> dict:
|
||||
return {**_JOINT_FEATURES, "camera": (480, 640, 3)}
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
return _JOINT_FEATURES
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return True
|
||||
|
||||
def connect(self, calibrate=True):
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self):
|
||||
pass
|
||||
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
return {f"{j}.pos": float(i) for i, j in enumerate(_JOINT_NAMES)} | {"camera": None}
|
||||
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
self._last_sent = action
|
||||
return action
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockTeleop(Teleoperator):
|
||||
"""Minimal Teleoperator."""
|
||||
|
||||
config_class = MockTeleopConfig
|
||||
name = "mock_teleop"
|
||||
|
||||
def __init__(self):
|
||||
# bypass filesystem calibration setup; initialize with identity pipelines directly
|
||||
self._output_pipeline = _make_identity_teleop_action_pipeline()
|
||||
self._input_pipeline = _make_identity_feedback_pipeline()
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict:
|
||||
return _JOINT_FEATURES
|
||||
|
||||
@property
|
||||
def raw_feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return True
|
||||
|
||||
def connect(self, calibrate=True):
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self):
|
||||
pass
|
||||
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def _get_action(self) -> RobotAction:
|
||||
return {f"{j}.pos": float(i) for i, j in enumerate(_JOINT_NAMES)}
|
||||
|
||||
def _send_feedback(self, feedback):
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
|
||||
# ─── Simple transform step (doubles all float values) ────────────────────────
|
||||
|
||||
|
||||
class DoubleActionStep(RobotActionProcessorStep):
|
||||
"""Doubles all float action values."""
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
return {k: v * 2 for k, v in action.items()}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
class RenameToEEObsStep(ObservationProcessorStep):
|
||||
"""Renames joint obs keys to EE-like keys for testing transform_features."""
|
||||
|
||||
def observation(self, obs: RobotObservation) -> RobotObservation:
|
||||
return {f"ee.{i}": v for i, v in enumerate(obs.values()) if isinstance(v, float)}
|
||||
|
||||
def transform_features(self, features):
|
||||
obs = features.get(PipelineFeatureType.OBSERVATION, {})
|
||||
new_obs = {f"ee.{i}": float for i in range(len([v for v in obs.values() if v == float]))}
|
||||
return {**features, PipelineFeatureType.OBSERVATION: new_obs}
|
||||
|
||||
|
||||
# ─── Tests: Robot pipeline interface ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_robot_default_pipeline_is_identity():
|
||||
"""With no custom pipeline, get_observation returns the same as _get_observation."""
|
||||
robot = MockRobot()
|
||||
raw = robot._get_observation()
|
||||
obs = robot.get_observation()
|
||||
assert obs == raw
|
||||
|
||||
|
||||
def test_robot_observation_caches_last_raw():
|
||||
"""get_observation caches raw result for IK use in send_action."""
|
||||
robot = MockRobot()
|
||||
robot.get_observation()
|
||||
assert robot._last_raw_obs is not None
|
||||
assert "shoulder_pan.pos" in robot._last_raw_obs
|
||||
|
||||
|
||||
def test_robot_default_send_action_is_identity():
|
||||
"""With no custom pipeline, send_action passes action unchanged to _send_action."""
|
||||
robot = MockRobot()
|
||||
robot.get_observation() # populate _last_raw_obs
|
||||
action = {f"{j}.pos": 1.0 for j in _JOINT_NAMES}
|
||||
sent = robot.send_action(action)
|
||||
assert sent == action
|
||||
assert robot._last_sent == action
|
||||
|
||||
|
||||
def test_robot_custom_output_pipeline_applied():
|
||||
"""A custom action pipeline is applied to the action before _send_action."""
|
||||
robot = MockRobot()
|
||||
double_pipeline = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[DoubleActionStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot.set_input_pipeline(double_pipeline)
|
||||
robot.get_observation() # populate _last_raw_obs
|
||||
action = {f"{j}.pos": 1.0 for j in _JOINT_NAMES}
|
||||
robot.send_action(action)
|
||||
assert all(v == 2.0 for v in robot._last_sent.values())
|
||||
|
||||
|
||||
def test_robot_observation_features_identity_matches_raw():
|
||||
"""observation_features equals raw_observation_features with identity pipeline."""
|
||||
robot = MockRobot()
|
||||
assert robot.observation_features == robot.raw_observation_features
|
||||
|
||||
|
||||
def test_robot_raw_observation_features_unchanged_after_pipeline():
|
||||
"""raw_observation_features is unaffected by the output pipeline."""
|
||||
robot = MockRobot()
|
||||
# Even with an FK-like renaming pipeline, raw_observation_features stays the same
|
||||
transform_pipeline = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[RenameToEEObsStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
robot.set_output_pipeline(transform_pipeline)
|
||||
# raw should still be joints + camera
|
||||
raw = robot.raw_observation_features
|
||||
assert "shoulder_pan.pos" in raw
|
||||
assert "camera" in raw
|
||||
|
||||
|
||||
def test_robot_action_features_identity_matches_raw():
|
||||
"""action_features equals raw_action_features with identity input pipeline."""
|
||||
robot = MockRobot()
|
||||
assert robot.action_features == robot.raw_action_features
|
||||
|
||||
|
||||
def test_robot_raw_action_features_unchanged_after_pipeline():
|
||||
"""raw_action_features is unaffected by any pipeline."""
|
||||
robot = MockRobot()
|
||||
double_pipeline = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[DoubleActionStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot.set_input_pipeline(double_pipeline)
|
||||
assert robot.raw_action_features == _JOINT_FEATURES
|
||||
|
||||
|
||||
def test_robot_set_output_pipeline_replaces_identity():
|
||||
"""set_output_pipeline replaces the default identity."""
|
||||
robot = MockRobot()
|
||||
p = _make_identity_observation_pipeline()
|
||||
robot.set_output_pipeline(p)
|
||||
assert robot._output_pipeline is p
|
||||
|
||||
|
||||
def test_robot_set_input_pipeline_replaces_identity():
|
||||
robot = MockRobot()
|
||||
p = _make_identity_robot_action_pipeline()
|
||||
robot.set_input_pipeline(p)
|
||||
assert robot._input_pipeline is p
|
||||
|
||||
|
||||
# ─── Tests: Teleoperator pipeline interface ───────────────────────────────────
|
||||
|
||||
|
||||
def test_teleop_default_get_action_is_identity():
|
||||
"""With no custom pipeline, get_action returns the same as _get_action."""
|
||||
teleop = MockTeleop()
|
||||
raw = teleop._get_action()
|
||||
action = teleop.get_action()
|
||||
assert action == raw
|
||||
|
||||
|
||||
def test_teleop_action_features_identity_matches_raw():
|
||||
"""action_features equals raw_action_features with identity pipeline."""
|
||||
teleop = MockTeleop()
|
||||
assert teleop.action_features == teleop.raw_action_features
|
||||
|
||||
|
||||
def test_teleop_feedback_features_identity_matches_raw():
|
||||
"""feedback_features equals raw_feedback_features with identity input pipeline."""
|
||||
teleop = MockTeleop()
|
||||
assert teleop.feedback_features == teleop.raw_feedback_features
|
||||
|
||||
|
||||
def test_teleop_feedback_features_empty_when_raw_empty():
|
||||
"""feedback_features returns empty dict when raw_feedback_features is empty."""
|
||||
teleop = MockTeleop()
|
||||
assert teleop.feedback_features == {}
|
||||
|
||||
|
||||
def test_teleop_set_output_pipeline():
|
||||
teleop = MockTeleop()
|
||||
p = _make_identity_teleop_action_pipeline()
|
||||
teleop.set_output_pipeline(p)
|
||||
assert teleop._output_pipeline is p
|
||||
|
||||
|
||||
def test_teleop_send_feedback_calls_send_feedback_impl():
|
||||
"""send_feedback applies identity pipeline and delegates to _send_feedback."""
|
||||
teleop = MockTeleop()
|
||||
received = {}
|
||||
|
||||
def capture(fb):
|
||||
received.update(fb)
|
||||
|
||||
teleop._send_feedback = capture
|
||||
teleop.send_feedback({"key": 1.0})
|
||||
assert received == {"key": 1.0}
|
||||
|
||||
|
||||
# ─── Tests: Compatibility checks ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_check_action_space_compatibility_matching():
|
||||
"""No warning when teleop output and robot action features match."""
|
||||
teleop = MockTeleop()
|
||||
robot = MockRobot()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
check_action_space_compatibility(teleop, robot) # should not warn
|
||||
|
||||
|
||||
def test_check_action_space_compatibility_mismatch_warns():
|
||||
"""Warning issued when teleop and robot action features differ."""
|
||||
|
||||
class EETeleop(MockTeleop):
|
||||
@property
|
||||
def raw_action_features(self):
|
||||
return _EE_FEATURES
|
||||
|
||||
teleop = EETeleop()
|
||||
robot = MockRobot() # still returns joint features
|
||||
with pytest.warns(UserWarning, match="Action space mismatch"):
|
||||
check_action_space_compatibility(teleop, robot)
|
||||
|
||||
|
||||
def test_check_observation_space_compatibility_no_feedback():
|
||||
"""No warning when teleop has empty feedback_features."""
|
||||
robot = MockRobot()
|
||||
teleop = MockTeleop()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
check_observation_space_compatibility(robot, teleop) # empty feedback → no warning
|
||||
|
||||
|
||||
# ─── Tests: build_dataset_features ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_build_dataset_features_identity():
|
||||
"""With identity pipelines, dataset features contain joint keys."""
|
||||
robot = MockRobot()
|
||||
teleop = MockTeleop()
|
||||
features = build_dataset_features(robot, teleop, use_videos=False)
|
||||
# Should contain action features (joint names)
|
||||
action_keys = {k for k in features if "action" in k or any(j in k for j in _JOINT_NAMES)}
|
||||
assert len(action_keys) > 0
|
||||
|
||||
|
||||
def test_build_dataset_features_includes_images_when_use_videos_true():
|
||||
"""Image features are included when use_videos=True."""
|
||||
robot = MockRobot()
|
||||
teleop = MockTeleop()
|
||||
feats_with = build_dataset_features(robot, teleop, use_videos=True)
|
||||
feats_without = build_dataset_features(robot, teleop, use_videos=False)
|
||||
# With videos should have more features (camera)
|
||||
assert len(feats_with) >= len(feats_without)
|
||||
|
||||
|
||||
# ─── Tests: Factory identity pipeline helpers ─────────────────────────────────
|
||||
|
||||
|
||||
def test_make_identity_observation_pipeline_is_noop():
|
||||
pipeline = _make_identity_observation_pipeline()
|
||||
obs = {"shoulder_pan.pos": 1.0, "camera": None}
|
||||
result = pipeline(obs)
|
||||
assert result == obs
|
||||
|
||||
|
||||
def test_make_identity_robot_action_pipeline_is_noop():
|
||||
pipeline = _make_identity_robot_action_pipeline()
|
||||
action = {"shoulder_pan.pos": 1.0}
|
||||
obs = {"shoulder_pan.pos": 0.0}
|
||||
result = pipeline((action, obs))
|
||||
assert result == action
|
||||
|
||||
|
||||
def test_make_identity_teleop_action_pipeline_is_noop():
|
||||
pipeline = _make_identity_teleop_action_pipeline()
|
||||
action = {"shoulder_pan.pos": 1.0}
|
||||
result = pipeline(action)
|
||||
assert result == action
|
||||
@@ -24,11 +24,6 @@ def mock_metrics():
|
||||
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int):
|
||||
self.num_processes = num_processes
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
meter = AverageMeter("loss", ":.2f")
|
||||
assert meter.name == "loss"
|
||||
@@ -87,37 +82,6 @@ def test_metrics_tracker_step(mock_metrics):
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=10,
|
||||
accelerator=MockAccelerator(num_processes=2),
|
||||
)
|
||||
assert tracker.steps == 10
|
||||
assert tracker.samples == 10 * 32 * 2
|
||||
assert tracker.episodes == tracker.samples / (1000 / 50)
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_step_with_accelerator(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=5,
|
||||
accelerator=MockAccelerator(num_processes=2),
|
||||
)
|
||||
tracker.step()
|
||||
assert tracker.steps == 6
|
||||
assert tracker.samples == (5 * 32 * 2) + (32 * 2)
|
||||
assert tracker.episodes == tracker.samples / (1000 / 50)
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_getattr(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
assert tracker.loss == mock_metrics["loss"]
|
||||
|
||||
Reference in New Issue
Block a user