Package folder structure (#1417)

* Move files

* Replace imports & paths

* Update relative paths

* Update doc symlinks

* Update instructions paths

* Fix imports

* Update grpc files

* Update more instructions

* Downgrade grpc-tools

* Update manifest

* Update more paths

* Update config paths

* Update CI paths

* Update bandit exclusions

* Remove walkthrough section
This commit is contained in:
Simon Alibert
2025-07-01 16:34:46 +02:00
committed by GitHub
parent 483be9aac2
commit d4ee470b00
268 changed files with 862 additions and 890 deletions

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Use this script to get a quick summary of your system config.
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
"""
import platform
HAS_HF_HUB = True
HAS_HF_DATASETS = True
HAS_NP = True
HAS_TORCH = True
HAS_LEROBOT = True
try:
import huggingface_hub
except ImportError:
HAS_HF_HUB = False
try:
import datasets
except ImportError:
HAS_HF_DATASETS = False
try:
import numpy as np
except ImportError:
HAS_NP = False
try:
import torch
except ImportError:
HAS_TORCH = False
try:
import lerobot
except ImportError:
HAS_LEROBOT = False
lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
# TODO(aliberts): refactor into an actual command `lerobot env`
def display_sys_info() -> dict:
"""Run this to get basic system info to help for tracking issues & bugs."""
info = {
"`lerobot` version": lerobot_version,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"Huggingface_hub version": hf_hub_version,
"Dataset version": hf_datasets_version,
"Numpy version": np_version,
"PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
"Cuda version": cuda_version,
"Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(format_dict(info))
return info
def format_dict(d: dict) -> str:
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
if __name__ == "__main__":
display_sys_info()

506
src/lerobot/scripts/eval.py Normal file
View File

@@ -0,0 +1,506 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a policy on an environment by running rollouts and computing metrics.
Usage examples:
You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht)
for 10 episodes.
```
python -m lerobot.scripts.eval \
--policy.path=lerobot/diffusion_pusht \
--env.type=pusht \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--use_amp=false \
--device=cuda
```
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
```
python -m lerobot.scripts.eval \
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
--env.type=pusht \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--use_amp=false \
--device=cuda
```
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
"""
import json
import logging
import threading
import time
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Callable
import einops
import gymnasium as gym
import numpy as np
import torch
from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
inside_slurm,
)
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
Note that all environments in the batch are run until the last environment is done. This means some
data will probably need to be discarded (for environments that aren't the first one to be done).
The return dictionary contains:
(optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
keys. NOTE that this has an extra sequence element relative to the other keys in the
dictionary. This is because an extra observation is included for after the environment is
terminated or truncated.
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
including the last observations).
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
environment termination/truncation).
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
the first True is followed by True's all the way till the end. This can be used for masking
extraneous elements from the sequences above.
Args:
env: The batch of environments.
policy: The policy. Must be a PyTorch nn module.
seeds: The environments are seeded once at the start of the rollout. If provided, this argument
specifies the seeds for each of the environments.
return_observations: Whether to include all observations in the returned rollout data. Observations
are returned optionally because they typically take more memory to cache. Defaults to False.
render_callback: Optional rendering callback to be used after the environments are reset, and after
every step.
Returns:
The dictionary described above.
"""
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
device = get_device_from_parameters(policy)
# Reset the policy and environments.
policy.reset()
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
all_observations = []
all_actions = []
all_rewards = []
all_successes = []
all_dones = []
step = 0
# Keep track of which environments are done.
done = np.array([False] * env.num_envs)
max_steps = env.call("_max_episode_steps")[0]
progbar = trange(
max_steps,
desc=f"Running rollout with at most {max_steps} steps",
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
leave=False,
)
check_env_attributes_and_types(env)
while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
with torch.inference_mode():
action = policy.select_action(observation)
# Convert to CPU / numpy.
action = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action)
if render_callback is not None:
render_callback(env)
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
else:
successes = [False] * env.num_envs
# Keep track of which environments are done so far.
done = terminated | truncated | done
all_actions.append(torch.from_numpy(action))
all_rewards.append(torch.from_numpy(reward))
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
if return_observations:
observation = preprocess_observation(observation)
all_observations.append(deepcopy(observation))
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
"action": torch.stack(all_actions, dim=1),
"reward": torch.stack(all_rewards, dim=1),
"success": torch.stack(all_successes, dim=1),
"done": torch.stack(all_dones, dim=1),
}
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"):
policy.use_original_modules()
return ret
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
) -> dict:
"""
Args:
env: The batch of environments.
policy: The policy.
n_episodes: The number of episodes to evaluate.
max_episodes_rendered: Maximum number of episodes to render into videos.
videos_dir: Where to save rendered videos.
return_episode_data: Whether to return episode data for online training. Incorporates the data into
the "episodes" key of the returned dictionary.
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
seed is incremented by 1. If not provided, the environments are not manually seeded.
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
if not isinstance(policy, PreTrainedPolicy):
raise ValueError(
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
)
start = time.time()
policy.eval()
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
# divisible by env.num_envs we end up discarding some data in the last batch.
n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)
# Keep track of some metrics.
sum_rewards = []
max_rewards = []
all_successes = []
all_seeds = []
threads = [] # for video saving threads
n_episodes_rendered = 0 # for saving the correct number of videos
# Callback for visualization.
def render_frame(env: gym.vector.VectorEnv):
# noqa: B023
if n_episodes_rendered >= max_episodes_rendered:
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
if max_episodes_rendered > 0:
video_paths: list[str] = []
if return_episode_data:
episode_data: dict | None = None
# we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []
if start_seed is None:
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env,
policy,
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
# this won't be included).
n_steps = rollout_data["done"].shape[1]
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
# Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
else:
all_seeds.append(None)
# FIXME: episode_data is either None or it doesn't exist
if return_episode_data:
this_episode_data = _compile_episode_data(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
batch_stacked_frames = np.stack(ep_frames, axis=1) # (b, t, *)
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if n_episodes_rendered >= max_episodes_rendered:
break
videos_dir.mkdir(parents=True, exist_ok=True)
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=write_video,
args=(
str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
thread.start()
threads.append(thread)
n_episodes_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
)
# Wait till all video rendering threads are done.
for thread in threads:
thread.join()
# Compile eval info.
info = {
"per_episode": [
{
"episode_ix": i,
"sum_reward": sum_reward,
"max_reward": max_reward,
"success": success,
"seed": seed,
}
for i, (sum_reward, max_reward, success, seed) in enumerate(
zip(
sum_rewards[:n_episodes],
max_rewards[:n_episodes],
all_successes[:n_episodes],
all_seeds[:n_episodes],
strict=True,
)
)
],
"aggregated": {
"avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])),
"avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])),
"pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100),
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / n_episodes,
},
}
if return_episode_data:
info["episodes"] = episode_data
if max_episodes_rendered > 0:
info["video_paths"] = video_paths
return info
def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)`
Compiles all the rollout data into a Hugging Face dataset.
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
"""
ep_dicts = []
total_frames = 0
for ep_ix in range(rollout_data["action"].shape[0]):
# + 2 to include the first done frame and the last observation frame.
num_frames = done_indices[ep_ix].item() + 2
total_frames += num_frames
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
}
# For the last observation frame, all other keys will just be copy padded.
for k in ep_dict:
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
for key in rollout_data["observation"]:
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
ep_dicts.append(ep_dict)
data_dict = {}
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
return data_dict
@parser.wrap()
def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
)
print(info["aggregated"])
# Save info
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2)
env.close()
logging.info("End of eval")
if __name__ == "__main__":
init_logging()
eval_main()

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simple script to control a robot from teleoperation.
Example:
```shell
python -m lerobot.scripts.server.find_joint_limits \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
```
"""
import time
from dataclasses import dataclass
import draccus
import numpy as np
from lerobot.model.kinematics import RobotKinematics
from lerobot.robots import ( # noqa: F401
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
gamepad,
koch_leader,
make_teleoperator_from_config,
so100_leader,
)
@dataclass
class FindJointLimitsConfig:
teleop: TeleoperatorConfig
robot: RobotConfig
# Limit the maximum frames per second. By default, no limit.
teleop_time_s: float = 30
# Display all cameras on screen
display_data: bool = False
@draccus.wrap()
def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
teleop = make_teleoperator_from_config(cfg.teleop)
robot = make_robot_from_config(cfg.robot)
teleop.connect()
robot.connect()
start_episode_t = time.perf_counter()
robot_type = getattr(robot.config, "robot_type", "so101")
if "so100" in robot_type or "so101" in robot_type:
# Note to be compatible with the rest of the codebase,
# we are using the new calibration method for so101 and so100
robot_type = "so_new_calibration"
kinematics = RobotKinematics(robot_type=robot_type)
# Initialize min/max values
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3]
max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
while True:
action = teleop.get_action()
robot.send_action(action)
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3]
# Skip initial warmup period
if (time.perf_counter() - start_episode_t) < 5:
continue
# Update min/max values
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)
if time.perf_counter() - start_episode_t > cfg.teleop_time_s:
print(f"Max ee position {np.round(max_ee, 4).tolist()}")
print(f"Min ee position {np.round(min_ee, 4).tolist()}")
print(f"Max joint pos position {np.round(max_pos, 4).tolist()}")
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
break
if __name__ == "__main__":
find_joint_and_ee_bounds()

View File

@@ -0,0 +1,709 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor server runner for distributed HILSerl robot policy training.
This script implements the actor component of the distributed HILSerl architecture.
It executes the policy in the robot environment, collects experience,
and sends transitions to the learner server for policy updates.
Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
server is started before launching the actor.
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
reduce interventions as the policy improves.
**WORKFLOW**:
1. Determine robot workspace bounds using `find_joint_limits.py`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
5. Start this actor server with the same configuration
6. Use human interventions to guide policy learning
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
import time
from functools import lru_cache
from queue import Empty
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so100_follower # noqa: F401
from lerobot.scripts.rl import learner_service
from lerobot.scripts.rl.gym_manipulator import make_robot_env
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.queue import get_last_item_from_queue
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
)
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
@parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig):
cfg.validate()
display_pid = False
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}")
is_threaded = use_threads(cfg)
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")
if not establish_learner_connection(learner_client, shutdown_event):
logging.error("[ACTOR] Failed to establish connection with Learner")
return
if not use_threads(cfg):
# If we use multithreading, we can reuse the channel
grpc_channel.close()
grpc_channel = None
logging.info("[ACTOR] Connection with Learner established")
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from multiprocessing import Process
concurrency_entity = Process
receive_policy_process = concurrency_entity(
target=receive_policy,
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process = concurrency_entity(
target=send_transitions,
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
daemon=True,
)
interactions_process = concurrency_entity(
target=send_interactions,
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process.start()
interactions_process.start()
receive_policy_process.start()
act_with_policy(
cfg=cfg,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
interactions_queue=interactions_queue,
)
logging.info("[ACTOR] Policy process joined")
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
transitions_process.join()
logging.info("[ACTOR] Transitions process joined")
interactions_process.join()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
#################################################
# Core algorithm functions #
#################################################
def act_with_policy(
cfg: TrainRLServerPipelineConfig,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg: Configuration settings for the interaction process.
shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner.
interactions_queue: Queue to send interactions to the learner.
"""
# Initialize logging for multiprocessing
if not use_threads(cfg):
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor policy process logging initialized")
logging.info("make_env online")
online_env = make_robot_env(cfg=cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
policy_timer = TimerManager("Policy inference", log=False)
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = policy.select_action(batch=obs)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# NOTE: The action space for demonstration before hand is with the full action space
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
transitions=list_transition_to_send_to_learner,
transitions_queue=transitions_queue,
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(policy_timer)
policy_timer.reset()
# Calculate intervention rate
intervention_rate = 0.0
if episode_total_steps > 0:
intervention_rate = episode_intervention_steps / episode_total_steps
# Send episodic reward to the learner
interactions_queue.put(
python_object_to_bytes(
{
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
"Intervention rate": intervention_rate,
**stats,
}
)
)
# Reset intervention counters
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.env.fps - dt_time)
#################################################
# Communication Functions - Group all gRPC/messaging functions #
#################################################
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event to check if the connection should be established.
attempts (int): The number of attempts to establish the connection.
Returns:
bool: True if the connection is established, False otherwise.
"""
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
return False
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
time.sleep(2)
return False
@lru_cache(maxsize=1)
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
)
stub = services_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
return stub, channel
def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
"""Receive parameters from the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
iterator = learner_client.StreamParameters(services_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Received policy loop stopped")
def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes:
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized")
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendTransitions(
transitions_stream(
shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout
)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Transitions process stopped")
def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor interactions process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendInteractions(
interactions_stream(
shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout
)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming interactions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
except Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
yield from send_bytes_in_chunks(
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return services_pb2.Empty()
def interactions_stream(
shutdown_event: Event,
interactions_queue: Queue,
timeout: float, # type: ignore
) -> services_pb2.Empty:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
except Empty:
logging.debug("[ACTOR] Interaction queue is empty")
continue
yield from send_bytes_in_chunks(
message,
services_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return services_pb2.Empty()
#################################################
# Policy functions #
#################################################
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
#################################################
# Utilities functions #
#################################################
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
transition_to_send_to_learner = []
for transition in transitions:
tr = move_transition_to_device(transition=transition, device="cpu")
for key, value in tr["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
transition_to_send_to_learner.append(tr)
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
timer (TimerManager): The timer with collected metrics.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
if timer.count > 1:
avg_fps = timer.fps_avg
p90_fps = timer.fps_percentile(90)
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
stats = {
"Policy frequency [Hz]": avg_fps,
"Policy frequency 90th-p [Hz]": p90_fps,
}
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int):
if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.actor == "threads"
if __name__ == "__main__":
actor_cli()

View File

@@ -0,0 +1,314 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from copy import deepcopy
from pathlib import Path
from typing import Dict, Tuple
import cv2
# import torch.nn.functional as F # noqa: N812
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
index_x, index_y = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal index_x, index_y, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
index_x, index_y = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Find the first row in the dataset and extract the image in order to be used for the crop.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: Tuple[int, int] = (128, 128),
push_to_hub: bool = False,
task: str = "",
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0
for frame_idx in tqdm(range(len(original_dataset))):
frame = original_dataset[frame_idx]
# Create a copy of the frame to add to the new dataset
new_frame = {}
for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
continue
if key in ("next.done", "next.reward"):
# if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0)
if key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
# Apply crop then resize.
cropped = F.crop(value, top, left, height, width)
value = F.resize(cropped, resize_size)
value = value.clamp(0, 1)
new_frame[key] = value
new_dataset.add_frame(new_frame, task=task)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
new_dataset.save_episode()
prev_episode_index = frame["episode_index"].item()
# Save the last episode
new_dataset.save_episode()
if push_to_hub:
new_dataset.push_to_hub()
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
parser.add_argument(
"--push-to-hub",
type=bool,
default=False,
help="Whether to push the new dataset to the hub.",
)
parser.add_argument(
"--task",
type=str,
default="",
help="The natural language task to describe the dataset.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path) as f:
rois = json.load(f)
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
push_to_hub=args.push_to_hub,
task=args.task,
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)

View File

@@ -0,0 +1,74 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
so100_follower,
)
from lerobot.scripts.rl.gym_manipulator import make_robot_env
from lerobot.teleoperators import (
gamepad, # noqa: F401
so101_leader, # noqa: F401
)
logging.basicConfig(level=logging.INFO)
def eval_policy(env, policy, n_episodes):
sum_reward_episode = []
for _ in range(n_episodes):
obs, _ = env.reset()
episode_reward = 0.0
while True:
action = policy.select_action(obs)
obs, reward, terminated, truncated, _ = env.step(action)
episode_reward += reward
if terminated or truncated:
break
sum_reward_episode.append(episode_reward)
logging.info(f"Success after 20 steps {sum_reward_episode}")
logging.info(f"success rate {sum(sum_reward_episode) / len(sum_reward_episode)}")
@parser.wrap()
def main(cfg: TrainRLServerPipelineConfig):
env_cfg = cfg.env
env = make_robot_env(env_cfg)
dataset_cfg = cfg.dataset
dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id)
dataset_meta = dataset.meta
policy = make_policy(
cfg=cfg.policy,
# env_cfg=cfg.env,
ds_meta=dataset_meta,
)
policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
policy.eval()
eval_policy(env, policy=policy, n_episodes=10)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,118 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from multiprocessing import Event, Queue
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.utils.queue import get_last_item_from_queue
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
check transport.proto for the gRPC service definition
"""
def __init__(
self,
shutdown_event: Event, # type: ignore
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: Queue,
interaction_message_queue: Queue,
queue_get_timeout: float = 0.001,
):
self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
last_push_time = 0
while not self.shutdown_event.is_set():
time_since_last_push = time.time() - last_push_time
if time_since_last_push < self.seconds_between_pushes:
self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push)
# Continue, because we could receive a shutdown event,
# and it's checked in the while loop
continue
logging.info("[LEARNER] Push parameters to the Actor")
buffer = get_last_item_from_queue(
self.parameters_queue, block=True, timeout=self.queue_get_timeout
)
if buffer is None:
continue
yield from send_bytes_in_chunks(
buffer,
services_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
last_push_time = time.time()
logging.info("[LEARNER] Parameters sent")
logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.interaction_message_queue,
self.shutdown_event,
log_prefix="[LEARNER] interactions",
)
logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802
return services_pb2.Empty()

View File

@@ -0,0 +1,291 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
has_method,
init_logging,
)
from lerobot.utils.wandb_utils import WandBLogger
def update_policy(
train_metrics: MetricsTracker,
policy: PreTrainedPolicy,
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
logging.info(pformat(cfg.to_dict()))
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.episode_data_index,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
)
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
train_tracker, output_dict = update_policy(
train_tracker,
policy,
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.policy.use_amp,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env:
eval_env.close()
logging.info("End of training")
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
if __name__ == "__main__":
init_logging()
train()

View File

@@ -0,0 +1,292 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesn't always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Examples:
- Visualize data stored on a local machine:
```
local$ python -m lerobot.scripts.visualize_dataset \
--repo-id lerobot/pusht \
--episode-index 0
```
- Visualize data stored on a distant machine with a local viewer:
```
distant$ python -m lerobot.scripts.visualize_dataset \
--repo-id lerobot/pusht \
--episode-index 0 \
--save 1 \
--output-dir path/to/directory
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python -m lerobot.scripts.visualize_dataset \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
--ws-port 9087
local$ rerun ws://localhost:9087
```
"""
import argparse
import gc
import logging
import time
from pathlib import Path
from typing import Iterator
import numpy as np
import rerun as rr
import torch
import torch.utils.data
import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
return hwc_uint8_numpy
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
) -> Path | None:
if save:
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
)
repo_id = dataset.repo_id
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
logging.info("Starting Rerun")
if mode not in ["local", "distant"]:
raise ValueError(mode)
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc.collect()
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun")
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.meta.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
if mode == "local" and save:
# save .rrd locally
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repo_id_str = repo_id.replace("/", "_")
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
rr.save(rrd_path)
return rrd_path
elif mode == "distant":
# stop the process from exiting since it is serving the websocket connection
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Ctrl-C received. Exiting.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
)
parser.add_argument(
"--episode-index",
type=int,
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size loaded by DataLoader.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of processes of Dataloader for loading the data.",
)
parser.add_argument(
"--mode",
type=str,
default="local",
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. "
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
"--web-port",
type=int,
default=9090,
help="Web port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--ws-port",
type=int,
default=9087,
help="Web socket port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--save",
type=int,
default=0,
help=(
"Save a .rrd file in the directory provided by `--output-dir`. "
"It also deactivates the spawning of a viewer. "
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
),
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
visualize_dataset(dataset, **vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,482 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Example of usage:
- Visualize data stored on a local machine:
```bash
local$ python -m lerobot.scripts.visualize_dataset_html \
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python -m lerobot.scripts.visualize_dataset_html \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```bash
python -m lerobot.scripts.visualize_dataset_html \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4
```
"""
import argparse
import csv
import json
import logging
import re
import shutil
import tempfile
from io import StringIO
from pathlib import Path
import numpy as np
import pandas as pd
import requests
from flask import Flask, redirect, render_template, request, url_for
from lerobot import available_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import IterableNamespace
from lerobot.utils.utils import init_logging
def run_server(
dataset: LeRobotDataset | IterableNamespace | None,
episodes: list[int] | None,
host: str,
port: str,
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
def hommepage(dataset=dataset):
if dataset:
dataset_namespace, dataset_name = dataset.repo_id.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)
dataset_param, episode_param = None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)
featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
def show_first_episode(dataset_namespace, dataset_name):
first_episode_id = 0
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=first_episode_id,
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
dataset = get_dataset_info(repo_id)
except FileNotFoundError:
return (
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
400,
)
dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
major_version = int(match.group(1))
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
if isinstance(dataset, LeRobotDataset)
else dataset.total_frames,
"num_episodes": dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes,
"fps": dataset.fps,
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
"filename": video_path.parent.name,
}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.video_path.format(
episode_chunk=int(episode_id) // dataset.chunks_size,
video_key=video_key,
episode_index=episode_id,
),
"filename": video_key,
}
for video_key in video_keys
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
ignored_columns=ignored_columns,
)
app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
ignored_columns = []
for column_name in selected_columns:
shape = dataset.features[column_name]["shape"]
shape_dim = len(shape)
if shape_dim > 1:
selected_columns.remove(column_name)
ignored_columns.append(column_name)
# init header of csv with state and action names
header = ["timestamp"]
for column_name in selected_columns:
dim_state = (
dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
header += column_names
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("pandas")
)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# Convert data to CSV string
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
# Write header
csv_writer.writerow(header)
# Write data rows
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns, ignored_columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.meta.video_keys
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
return IterableNamespace(dataset_info)
def visualize_dataset_html(
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
) -> Path | None:
init_logging()
template_dir = Path(__file__).resolve().parent.parent / "templates"
if output_dir is None:
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
if dataset is None:
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
# Create a simlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Web host used by the http server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
visualize_dataset_html(dataset, **vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,130 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize effects of image transforms for a given configuration.
This script will generate examples of transformed images as they are output by LeRobot dataset.
Additionally, each individual transform can be visualized separately as well as examples of combined transforms
Example:
```bash
python -m lerobot.scripts.visualize_image_transforms \
--repo_id=lerobot/pusht \
--episodes='[0]' \
--image_transforms.enable=True
```
"""
import logging
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
import draccus
from torchvision.transforms import ToPILImage
from lerobot.configs.default import DatasetConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import (
ImageTransforms,
ImageTransformsConfig,
make_transform_from_config,
)
OUTPUT_DIR = Path("outputs/image_transforms")
to_pil = ToPILImage()
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)
tfs = ImageTransforms(cfg)
for i in range(1, n_examples + 1):
transformed_frame = tfs(original_frame)
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
print("Combined transforms examples saved to:")
print(f" {output_dir_all}")
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
if not cfg.enable:
logging.warning(
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
)
return
print("Individual transforms examples saved to:")
for tf_name, tf_cfg in cfg.tfs.items():
# Apply a few transformation with random value in min_max range
output_dir_single = output_dir / tf_name
output_dir_single.mkdir(parents=True, exist_ok=True)
tf = make_transform_from_config(tf_cfg)
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
# Apply min, max, average transformations
tf_cfg_kwgs_min = deepcopy(tf_cfg.kwargs)
tf_cfg_kwgs_max = deepcopy(tf_cfg.kwargs)
tf_cfg_kwgs_avg = deepcopy(tf_cfg.kwargs)
for key, (min_, max_) in tf_cfg.kwargs.items():
avg = (min_ + max_) / 2
tf_cfg_kwgs_min[key] = [min_, min_]
tf_cfg_kwgs_max[key] = [max_, max_]
tf_cfg_kwgs_avg[key] = [avg, avg]
tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
tf_frame_min = tf_min(original_frame)
tf_frame_max = tf_max(original_frame)
tf_frame_avg = tf_avg(original_frame)
to_pil(tf_frame_min).save(output_dir_single / "min.png", quality=100)
to_pil(tf_frame_max).save(output_dir_single / "max.png", quality=100)
to_pil(tf_frame_avg).save(output_dir_single / "mean.png", quality=100)
print(f" {output_dir_single}")
@draccus.wrap()
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,
revision=cfg.revision,
video_backend=cfg.video_backend,
)
output_dir = output_dir / cfg.repo_id.split("/")[-1]
output_dir.mkdir(parents=True, exist_ok=True)
# Get 1st frame from 1st camera of 1st episode
original_frame = dataset[0][dataset.meta.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")
save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples)
save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
if __name__ == "__main__":
visualize_image_transforms()