mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
Compare commits
147 Commits
feat/profi
...
feat/conve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42d35e47b2 | ||
|
|
1b9330a25a | ||
|
|
490ffa89a5 | ||
|
|
3d31f2ad53 | ||
|
|
af79dda8d9 | ||
|
|
952f455446 | ||
|
|
0747afdba7 | ||
|
|
992fb177c3 | ||
|
|
1db3401159 | ||
|
|
7868df27dc | ||
|
|
0e04f5fbbe | ||
|
|
fdccf7774b | ||
|
|
2a3d62259e | ||
|
|
2df4e25558 | ||
|
|
4062d0564a | ||
|
|
0a30636fc6 | ||
|
|
adad3698e1 | ||
|
|
84ffc28854 | ||
|
|
47aee1fdbe | ||
|
|
bbd64b9ce5 | ||
|
|
000e88760d | ||
|
|
35f36e8fba | ||
|
|
213ffe02cf | ||
|
|
2b03dec01f | ||
|
|
64a9dd3763 | ||
|
|
2ca6edc19e | ||
|
|
db36f01e8b | ||
|
|
c7a3b02625 | ||
|
|
267a753eda | ||
|
|
4048b02d4a | ||
|
|
f94092c169 | ||
|
|
1c79e3dec1 | ||
|
|
527ae8e557 | ||
|
|
890b1e473d | ||
|
|
6447352439 | ||
|
|
788544d936 | ||
|
|
59d108a807 | ||
|
|
218ebed3ef | ||
|
|
670d7f485f | ||
|
|
c993fea8ab | ||
|
|
066b81aec2 | ||
|
|
dcb02a951d | ||
|
|
ac0fd71f0a | ||
|
|
f98f01e81d | ||
|
|
23375cce3a | ||
|
|
8ffc00dbcd | ||
|
|
ec40fc41b5 | ||
|
|
5ec70f704e | ||
|
|
4c0ac93eb6 | ||
|
|
788dde3a34 | ||
|
|
e05d22cb7b | ||
|
|
3483e4441e | ||
|
|
2a76135b82 | ||
|
|
6a9834e8b6 | ||
|
|
a4d3a414ca | ||
|
|
a49760e2ba | ||
|
|
4e01f87a6e | ||
|
|
c8a5df963b | ||
|
|
18209e6194 | ||
|
|
4a466d94b6 | ||
|
|
9287c36f37 | ||
|
|
30ffa259b7 | ||
|
|
bee74c3eab | ||
|
|
83bf24cc9a | ||
|
|
3dbc3e60fb | ||
|
|
830a3b9f27 | ||
|
|
69b1f7b118 | ||
|
|
66454a0fbf | ||
|
|
012d428f7b | ||
|
|
1c17419224 | ||
|
|
9dde8829e6 | ||
|
|
0f66bbe2f9 | ||
|
|
6de5670912 | ||
|
|
5e39b4ce94 | ||
|
|
0a1da47527 | ||
|
|
6b482a93d6 | ||
|
|
d9b9cc80da | ||
|
|
c3e98db37d | ||
|
|
01d0b7b102 | ||
|
|
848a494ff6 | ||
|
|
378c147be6 | ||
|
|
d4fbf6ef39 | ||
|
|
8c1503dafa | ||
|
|
ba022dd091 | ||
|
|
13a1f68b8e | ||
|
|
58795d72c8 | ||
|
|
220997ff47 | ||
|
|
ee2566456a | ||
|
|
a231930044 | ||
|
|
6f0fc7f386 | ||
|
|
fde67dbae7 | ||
|
|
ad1ad11eac | ||
|
|
01bc89b6f4 | ||
|
|
8c43b3d05e | ||
|
|
d4af22418b | ||
|
|
eaec52a7b7 | ||
|
|
0a390de361 | ||
|
|
20b74ae1eb | ||
|
|
b9b880bd8b | ||
|
|
2866d0770f | ||
|
|
4375a05a9f | ||
|
|
4acf99f622 | ||
|
|
5a6ea09248 | ||
|
|
9c0836c8d0 | ||
|
|
b0cca75e5e | ||
|
|
54b5c805bf | ||
|
|
eab5543750 | ||
|
|
6b6a990f4c | ||
|
|
c2a05a1fde | ||
|
|
6c4d122198 | ||
|
|
34c5d4ce07 | ||
|
|
c1b28f0b58 | ||
|
|
53ecec5fb2 | ||
|
|
65738f0a80 | ||
|
|
5d184a7811 | ||
|
|
1a5c1ef9c7 | ||
|
|
7866c1f7d1 | ||
|
|
3666ac9346 | ||
|
|
3daab2acbb | ||
|
|
c36d2253d0 | ||
|
|
e2e6f6e666 | ||
|
|
ff0029f84b | ||
|
|
39ad2d16d4 | ||
|
|
689c5efc72 | ||
|
|
eda0b996cd | ||
|
|
15e7a9d541 | ||
|
|
52fb4143b5 | ||
|
|
93c80b2cb1 | ||
|
|
5fbbaa1bc0 | ||
|
|
71d1f5e2c9 | ||
|
|
b520941cd9 | ||
|
|
64ed5258e6 | ||
|
|
392a8c32a7 | ||
|
|
969ef745a2 | ||
|
|
6fe42a72db | ||
|
|
2487228ea7 | ||
|
|
76436ca1de | ||
|
|
fbf2f2222a | ||
|
|
02bc4e03e0 | ||
|
|
624eaf1175 | ||
|
|
aed3eb4a94 | ||
|
|
8426c64f42 | ||
|
|
7c2bbee613 | ||
|
|
9d6886dd08 | ||
|
|
d67ca342e9 | ||
|
|
57c9c21c39 | ||
|
|
38c14571cc |
@@ -19,13 +19,9 @@
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
- local: porting_datasets_v3
|
||||
title: Porting Large Datasets
|
||||
title: "Datasets"
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: smolvla
|
||||
title: Finetune SmolVLA
|
||||
@@ -41,14 +37,10 @@
|
||||
title: Koch v1.1
|
||||
- local: lekiwi
|
||||
title: LeKiwi
|
||||
- local: reachy2
|
||||
title: Reachy 2
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Feetech Motor Firmware Update
|
||||
|
||||
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Windows computer (Feetech software is only available for Windows)
|
||||
- Feetech motor control board
|
||||
- USB cable to connect the control board to your computer
|
||||
- Feetech motors connected to the control board
|
||||
|
||||
## Step 1: Download Feetech Software
|
||||
|
||||
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
|
||||
2. Download the latest version of the Feetech debugging software (FD)
|
||||
3. Install the software on your Windows computer
|
||||
|
||||
## Step 2: Hardware Setup
|
||||
|
||||
1. Connect your Feetech motors to the motor control board
|
||||
2. Connect the motor control board to your Windows computer via USB cable
|
||||
3. Ensure power is supplied to the motors
|
||||
|
||||
## Step 3: Configure Connection
|
||||
|
||||
1. Launch the Feetech debugging software
|
||||
2. Select the correct COM port from the port dropdown menu
|
||||
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
|
||||
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
|
||||
4. Click "Open" to establish communication with the control board
|
||||
|
||||
## Step 4: Scan for Motors
|
||||
|
||||
1. Once connected, click the "Search" button to detect all connected motors
|
||||
2. The software will automatically discover and list all motors on the bus
|
||||
3. Each motor will appear with its ID number
|
||||
|
||||
## Step 5: Update Firmware
|
||||
|
||||
For each motor you want to update:
|
||||
|
||||
1. **Select the motor** from the list by clicking on it
|
||||
2. **Click on Upgrade tab**:
|
||||
3. **Click on Online button**:
|
||||
- If an potential firmware update is found, it will be displayed in the box
|
||||
4. **Click on Upgrade button**:
|
||||
- The update progress will be displayed
|
||||
|
||||
## Step 6: Verify Update
|
||||
|
||||
1. After the update completes, the software should automatically refresh the motor information
|
||||
2. Verify that the firmware version has been updated to the expected version
|
||||
|
||||
## Important Notes
|
||||
|
||||
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
|
||||
|
||||
## Bonus: Motor Debugging on Linux/macOS
|
||||
|
||||
For debugging purposes only, you can use the open-source Feetech Debug Tool:
|
||||
|
||||
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
|
||||
|
||||
### Installation Instructions
|
||||
|
||||
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
|
||||
|
||||
**Limitations:**
|
||||
|
||||
- This tool is for debugging and parameter adjustment only
|
||||
- Firmware updates must still be done on Windows with official Feetech software
|
||||
@@ -1,169 +0,0 @@
|
||||
# LeRobotDataset v3.0
|
||||
|
||||
`LeRobotDataset v3.0` is a standardized format for robot learning data. It provides unified access to multi-modal time-series data, sensorimotor signals and multi‑camera video, as well as rich metadata for indexing, search, and visualization on the Hugging Face Hub.
|
||||
|
||||
This docs will guide you to:
|
||||
|
||||
- Understand the v3.0 design and directory layout
|
||||
- Record a dataset and push it to the Hub
|
||||
- Load datasets for training with `LeRobotDataset`
|
||||
- Stream datasets without downloading using `StreamingLeRobotDataset`
|
||||
- Migrate existing `v2.1` datasets to `v3.0`
|
||||
|
||||
## What’s new in `v3`
|
||||
|
||||
- **File-based storage**: Many episodes per Parquet/MP4 file (v2 used one file per episode).
|
||||
- **Relational metadata**: Episode boundaries and lookups are resolved through metadata, not filenames.
|
||||
- **Hub-native streaming**: Consume datasets directly from the Hub with `StreamingLeRobotDataset`.
|
||||
- **Lower file-system pressure**: Fewer, larger files ⇒ faster initialization and fewer issues at scale.
|
||||
- **Unified organization**: Clean directory layout with consistent path templates across data and videos.
|
||||
|
||||
## Installation
|
||||
|
||||
`LeRobotDataset v3.0` will be included in `lerobot >= 0.4.0`.
|
||||
|
||||
Until that stable release, you can use the main branch by following the [build from source instructions](./installation#from-source).
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Run the command below to record a dataset with the SO-101 and push to the Hub:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube"
|
||||
```
|
||||
|
||||
See the [recording guide](./il_robots#record-a-dataset) for more details.
|
||||
|
||||
## Format design
|
||||
|
||||
A core v3 principle is **decoupling storage from the user API**: data is stored efficiently (few large files), while the public API exposes intuitive episode-level access.
|
||||
|
||||
`v3` has three pillars:
|
||||
|
||||
1. **Tabular data**: Low‑dimensional, high‑frequency signals (states, actions, timestamps) stored in **Apache Parquet**. Access is memory‑mapped or streamed via the `datasets` stack.
|
||||
2. **Visual data**: Camera frames concatenated and encoded into **MP4**. Frames from the same episode are grouped; videos are sharded per camera for practical sizes.
|
||||
3. **Metadata**: JSON/Parquet records describing schema (feature names, dtypes, shapes), frame rates, normalization stats, and **episode segmentation** (start/end offsets into shared Parquet/MP4 files).
|
||||
|
||||
> To scale to millions of episodes, tabular rows and video frames from multiple episodes are **concatenated** into larger files. Episode‑specific views are reconstructed **via metadata**, not file boundaries.
|
||||
|
||||
<div style="display:flex; justify-content:center; gap:12px; flex-wrap:wrap;">
|
||||
<figure style="margin:0; text-align:center;">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobotdataset-v3/asset1datasetv3.png"
|
||||
alt="LeRobotDataset v3 diagram"
|
||||
width="220"
|
||||
/>
|
||||
<figcaption style="font-size:0.9em; color:#666;">
|
||||
From episode‑based to file‑based datasets
|
||||
</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
### Directory layout (simplified)
|
||||
|
||||
- **`meta/info.json`**: canonical schema (features, shapes/dtypes), FPS, codebase version, and **path templates** to locate data/video shards.
|
||||
- **`meta/stats.json`**: global feature statistics (mean/std/min/max) used for normalization; exposed as `dataset.meta.stats`.
|
||||
- **`meta/tasks.jsonl`**: natural‑language task descriptions mapped to integer IDs for task‑conditioned policies.
|
||||
- **`meta/episodes/`**: per‑episode records (lengths, tasks, offsets) stored as **chunked Parquet** for scalability.
|
||||
- **`data/`**: frame‑by‑frame **Parquet** shards; each file typically contains **many episodes**.
|
||||
- **`videos/`**: **MP4** shards per camera; each file typically contains **many episodes**.
|
||||
|
||||
## Load a dataset for training
|
||||
|
||||
`LeRobotDataset` returns Python dictionaries of PyTorch tensors and integrates with `torch.utils.data.DataLoader`. Here is a code example showing its use:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
repo_id = "yaak-ai/L2D-v3"
|
||||
|
||||
# 1) Load from the Hub (cached locally)
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# 2) Random access by index
|
||||
sample = dataset[100]
|
||||
print(sample)
|
||||
# {
|
||||
# 'observation.state': tensor([...]),
|
||||
# 'action': tensor([...]),
|
||||
# 'observation.images.front_left': tensor([C, H, W]),
|
||||
# 'timestamp': tensor(1.234),
|
||||
# ...
|
||||
# }
|
||||
|
||||
# 3) Temporal windows via delta_timestamps (seconds relative to t)
|
||||
delta_timestamps = {
|
||||
"observation.images.front_left": [-0.2, -0.1, 0.0] # 0.2s and 0.1s before current frame
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Accessing an index now returns a stack for the specified key(s)
|
||||
sample = dataset[100]
|
||||
print(sample["observation.images.front_left"].shape) # [T, C, H, W], where T=3
|
||||
|
||||
# 4) Wrap with a DataLoader for training
|
||||
batch_size = 16
|
||||
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
for batch in data_loader:
|
||||
observations = batch["observation.state"].to(device)
|
||||
actions = batch["action"].to(device)
|
||||
images = batch["observation.images.front_left"].to(device)
|
||||
# model.forward(batch)
|
||||
```
|
||||
|
||||
## Stream a dataset (no downloads)
|
||||
|
||||
Use `StreamingLeRobotDataset` to iterate directly from the Hub without local copies. This allows to stream large datasets without the need to downloading them onto disk or loading them onto memory, and is a key feature of the new dataset format.
|
||||
|
||||
```python
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
repo_id = "yaak-ai/L2D-v3"
|
||||
dataset = StreamingLeRobotDataset(repo_id) # streams directly from the Hub
|
||||
```
|
||||
|
||||
<div style="display:flex; justify-content:center; gap:12px; flex-wrap:wrap;">
|
||||
<figure style="margin:0; text-align:center;">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobotdataset-v3/streaming-lerobot.png"
|
||||
alt="StreamingLeRobotDataset"
|
||||
width="520"
|
||||
/>
|
||||
<figcaption style="font-size:0.9em; color:#666;">
|
||||
Stream directly from the Hub for on‑the‑fly training.
|
||||
</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
## Migrate `v2.1` → `v3.0`
|
||||
|
||||
A converter aggregates per‑episode files into larger shards and writes episode offsets/metadata. Convert your dataset using the instructions below.
|
||||
|
||||
```bash
|
||||
# Pre-release build with v3 support:
|
||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||
|
||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
```
|
||||
|
||||
**What it does**
|
||||
|
||||
- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, …
|
||||
- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, …
|
||||
- Updates `meta/episodes/*` (chunked Parquet) with per‑episode lengths, tasks, and byte/frame offsets.
|
||||
@@ -150,7 +150,7 @@ gsutil -m cp -r gs://gresearch/robotics/droid_100 /your/data/
|
||||
### Step 3: Port the Dataset
|
||||
|
||||
```bash
|
||||
python examples/port_datasets/port_droid.py \
|
||||
python examples/port_datasets/port_droid_rlds.py \
|
||||
--raw-dir /your/data/droid/1.0.1 \
|
||||
--repo-id your_id/droid_1.0.1 \
|
||||
--push-to-hub
|
||||
@@ -161,7 +161,7 @@ python examples/port_datasets/port_droid.py \
|
||||
For development, you can port a single shard:
|
||||
|
||||
```bash
|
||||
python examples/port_datasets/port_droid.py \
|
||||
python examples/port_datasets/port_droid_rlds.py \
|
||||
--raw-dir /your/data/droid/1.0.1 \
|
||||
--repo-id your_id/droid_1.0.1_test \
|
||||
--num-shards 2048 \
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
# Reachy 2
|
||||
|
||||
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
|
||||
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
|
||||
|
||||
## Teleoperate Reachy 2
|
||||
|
||||
Currently, there are two ways to teleoperate Reachy 2:
|
||||
|
||||
- Pollen Robotics’ VR teleoperation (not included in LeRobot).
|
||||
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
|
||||
|
||||
## Reachy 2 Simulation
|
||||
|
||||
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
|
||||
|
||||
1. Install [Docker Engine](https://docs.docker.com/engine/).
|
||||
2. Run (for MuJoCo):
|
||||
|
||||
```
|
||||
docker run --rm -it \
|
||||
--name reachy \
|
||||
--privileged \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--device-cgroup-rule='c 189:* rwm' \
|
||||
--group-add audio \
|
||||
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
-e DISPLAY="$DISPLAY" \
|
||||
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
-v /dev:/dev \
|
||||
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
--entrypoint /package/launch.sh \
|
||||
pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
```
|
||||
|
||||
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
|
||||
>
|
||||
> ```
|
||||
> docker run --rm -it \
|
||||
> --name reachy \
|
||||
> --privileged \
|
||||
> --network host \
|
||||
> --ipc host \
|
||||
> --device-cgroup-rule='c 189:* rwm' \
|
||||
> --group-add audio \
|
||||
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
> -e DISPLAY="$DISPLAY" \
|
||||
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
|
||||
> -v /dev:/dev \
|
||||
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
> --entrypoint /package/launch.sh \
|
||||
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
> start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
> ```
|
||||
|
||||
## Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- On your robot, check the **service images** meet the minimum versions:
|
||||
- **reachy2-core >= 1.7.5.2**
|
||||
- **webrtc >= 2.0.1.1**
|
||||
|
||||
Then, if you want to use VR teleoperation:
|
||||
|
||||
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
|
||||
Use version **>=v1.2.0**
|
||||
|
||||
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
|
||||
|
||||
Install LeRobot with Reachy 2 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[reachy2]"
|
||||
```
|
||||
|
||||
### (Optional but recommended) Install pollen_data_acquisition_server
|
||||
|
||||
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
|
||||
|
||||
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to your needs.
|
||||
|
||||
In your LeRobot environment, install the server from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
|
||||
cd pollen_data_acquisition_server
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
|
||||
|
||||
## Step 1: Recording
|
||||
|
||||
### Get Reachy 2 IP address
|
||||
|
||||
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
|
||||
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
|
||||
|
||||
### Launch recording
|
||||
|
||||
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
|
||||
|
||||
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robot’s motions.
|
||||
- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control.
|
||||
|
||||
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
|
||||
|
||||
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
|
||||
|
||||
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
|
||||
|
||||
```bash
|
||||
python -m pollen_data_acquisition_server.server
|
||||
```
|
||||
|
||||
Then get into the teleoperation application and choose "Data acquisition session".
|
||||
You can finally setup your session by following the screens displayed.
|
||||
|
||||
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
|
||||
|
||||
### Option 2: Using lerobot.record
|
||||
|
||||
Reachy 2 is fully supported by LeRobot’s recording features.
|
||||
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
|
||||
|
||||
**Example: start a recording without the mobile base:**
|
||||
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.id=r2-0000 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=false \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
#### Specific Options
|
||||
|
||||
**Extended setup overview (all options included):**
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=true \
|
||||
--robot.with_l_arm=true \
|
||||
--robot.with_r_arm=true \
|
||||
--robot.with_neck=true \
|
||||
--robot.with_antennas=true \
|
||||
--robot.with_left_teleop_camera=true \
|
||||
--robot.with_right_teleop_camera=true \
|
||||
--robot.with_torso_camera=false \
|
||||
--robot.disable_torque_on_disconnect=false \
|
||||
--robot.max_relative_target=5.0 \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.use_present_position=false \
|
||||
--teleop.with_mobile_base=false \
|
||||
--teleop.with_l_arm=true \
|
||||
--teleop.with_r_arm=true \
|
||||
--teleop.with_neck=true \
|
||||
--teleop.with_antennas=true \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
##### `--robot.use_external_commands`
|
||||
|
||||
Determine whether LeRobot robot.send_action() sends commands to the robot.
|
||||
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
|
||||
|
||||
##### `--teleop.use_present_position`
|
||||
|
||||
Determine whether the teleoperator reads the goal or present position of the robot.
|
||||
Must be set to true if a compliant Reachy 2 is used to control another one.
|
||||
|
||||
##### Use the relevant parts
|
||||
|
||||
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
|
||||
To avoid this, you can exclude specific parts from recording and replay using:
|
||||
|
||||
````
|
||||
--robot.with_<part>=false
|
||||
```,
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
It determine whether the corresponding part is recorded in the observations. True if not set.
|
||||
|
||||
By default, **all parts are recorded**.
|
||||
|
||||
The same per-part mechanism is available in `reachy2_teleoperator` as well.
|
||||
|
||||
````
|
||||
|
||||
--teleop.with\_<part>
|
||||
|
||||
```
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
Determine whether the corresponding part is recorded in the actions. True if not set.
|
||||
|
||||
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
|
||||
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
|
||||
|
||||
##### Use the relevant cameras
|
||||
|
||||
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
|
||||
|
||||
```
|
||||
|
||||
--robot.with_left_teleop_camera=<true|false>
|
||||
--robot.with_right_teleop_camera=<true|false>
|
||||
--robot.with_torso_camera=<true|false>
|
||||
|
||||
````
|
||||
|
||||
|
||||
## Step 2: Replay
|
||||
|
||||
Make sure the robot is configured with the same parts as the dataset:
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=false \
|
||||
--robot.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.episode=0
|
||||
--display_data=true
|
||||
````
|
||||
|
||||
## Step 3: Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/reachy2_test \
|
||||
--job_name=reachy2 \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=pollen_robotics/record_test_policy
|
||||
```
|
||||
|
||||
## Step 4: Evaluate
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=pollen_robotics/eval_record_test \
|
||||
--dataset.single_task="Evaluate reachy2 policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
@@ -1,116 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This script demonstrates how to train a Diffusion Policy on the PushT environment,
|
||||
using a dataset processed in streaming mode.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
|
||||
def main():
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_streaming_dataset")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Selects the "best" device available
|
||||
device = (
|
||||
torch.device("cuda")
|
||||
if torch.cuda.is_available()
|
||||
else torch.device("mps")
|
||||
if torch.backends.mps.is_available()
|
||||
else torch.device("cpu")
|
||||
)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
training_steps = 10
|
||||
log_freq = 1
|
||||
|
||||
dataset_id = (
|
||||
"aractingi/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (:
|
||||
)
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
|
||||
# Here, we use delta-timestamps to only provide ground truth actions for supervision
|
||||
delta_timestamps = {
|
||||
ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
|
||||
}
|
||||
|
||||
# Instantiating the training dataset in streaming mode allows to not consume up memory as the data is fetched
|
||||
# iteratively rather than being load into memory all at once. Retrieved frames are shuffled across epochs
|
||||
dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)
|
||||
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=16,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
prefetch_factor=2, # loads batches with multiprocessing while policy trains
|
||||
)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {
|
||||
k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v)
|
||||
for k, v in batch.items()
|
||||
}
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
|
||||
# batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
47
examples/port_datasets/convert_rt1_example.sh
Normal file
47
examples/port_datasets/convert_rt1_example.sh
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script for converting RT-1 dataset using SLURM
|
||||
# Make sure to modify the paths and parameters according to your setup
|
||||
|
||||
# Configuration
|
||||
RAW_DIR="/path/to/datasets/fractal20220817_data/0.1.0"
|
||||
REPO_ID="your_username/rt1_lerobot"
|
||||
LOGS_DIR="/path/to/logs"
|
||||
PARTITION="cpu" # Your SLURM partition name
|
||||
|
||||
# Step 1: Convert dataset using distributed processing
|
||||
echo "Starting RT-1 dataset conversion..."
|
||||
python examples/port_datasets/slurm_port_shards.py \
|
||||
--raw-dir "$RAW_DIR" \
|
||||
--repo-id "$REPO_ID" \
|
||||
--dataset-type rlds \
|
||||
--logs-dir "$LOGS_DIR" \
|
||||
--job-name rt1_conversion \
|
||||
--workers 32 \
|
||||
--num-shards 32 \
|
||||
--partition "$PARTITION" \
|
||||
--cpus-per-task 4 \
|
||||
--mem-per-cpu 2G \
|
||||
--slurm 1
|
||||
|
||||
# Step 2: Wait for jobs to complete (you can monitor with squeue)
|
||||
echo "Conversion jobs submitted. Monitor with 'squeue -u \$USER'"
|
||||
echo "Once all jobs complete, run the aggregation step:"
|
||||
echo ""
|
||||
echo "python examples/port_datasets/slurm_aggregate_shards.py \\"
|
||||
echo " --repo-id $REPO_ID \\"
|
||||
echo " --push-to-hub"
|
||||
|
||||
# Uncomment the following lines if you want to automatically aggregate
|
||||
# (but make sure all shards are complete first)
|
||||
|
||||
# echo "Waiting for jobs to complete..."
|
||||
# while [ $(squeue -u $USER -h | wc -l) -gt 0 ]; do
|
||||
# echo "Jobs still running, waiting 60 seconds..."
|
||||
# sleep 60
|
||||
# done
|
||||
|
||||
# echo "All jobs completed. Starting aggregation..."
|
||||
# python examples/port_datasets/slurm_aggregate_shards.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --push-to-hub
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,5 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_reachy2_camera import Reachy2CameraConfig
|
||||
from .reachy2_camera import Reachy2Camera
|
||||
"""Open X-Embodiment utilities for dataset conversion."""
|
||||
854
examples/port_datasets/oxe_utils/configs.py
Normal file
854
examples/port_datasets/oxe_utils/configs.py
Normal file
@@ -0,0 +1,854 @@
|
||||
"""
|
||||
Adapt from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py
|
||||
configs.py
|
||||
|
||||
Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
|
||||
|
||||
Configuration adopts the following structure:
|
||||
image_obs_keys:
|
||||
primary: primary external RGB
|
||||
secondary: secondary external RGB
|
||||
wrist: wrist RGB
|
||||
|
||||
depth_obs_keys:
|
||||
primary: primary external depth
|
||||
secondary: secondary external depth
|
||||
wrist: wrist depth
|
||||
|
||||
# Always 8-dim =>> changes based on `StateEncoding`
|
||||
state_obs_keys:
|
||||
StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
|
||||
StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
|
||||
StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
|
||||
|
||||
state_encoding: Type of `StateEncoding`
|
||||
action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
|
||||
"""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def zero_action_filter(traj: dict) -> bool:
|
||||
"""
|
||||
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
|
||||
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
|
||||
"""
|
||||
DROID_Q01 = tf.convert_to_tensor( # NOQA: N806
|
||||
[
|
||||
-0.7776297926902771,
|
||||
-0.5803514122962952,
|
||||
-0.5795090794563293,
|
||||
-0.6464047729969025,
|
||||
-0.7041108310222626,
|
||||
-0.8895104378461838,
|
||||
]
|
||||
)
|
||||
DROID_Q99 = tf.convert_to_tensor( # NOQA: N806
|
||||
[
|
||||
0.7597932070493698,
|
||||
0.5726242214441299,
|
||||
0.7351000607013702,
|
||||
0.6705610305070877,
|
||||
0.6464948207139969,
|
||||
0.8897542208433151,
|
||||
]
|
||||
)
|
||||
DROID_NORM_0_ACT = ( # NOQA: N806
|
||||
2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
|
||||
)
|
||||
|
||||
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
|
||||
|
||||
|
||||
# Defines Proprioceptive State Encoding Schemes
|
||||
class StateEncoding(IntEnum):
|
||||
# fmt: off
|
||||
NONE = -1 # No Proprioceptive State
|
||||
POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
|
||||
POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
|
||||
JOINT = 3 # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
|
||||
JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
|
||||
# fmt: on
|
||||
|
||||
|
||||
# Defines Action Encoding Schemes
|
||||
class ActionEncoding(IntEnum):
|
||||
# fmt: off
|
||||
EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
|
||||
EEF_POS_QUAT = 5 # EEF Delta XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
|
||||
JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1)
|
||||
JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
|
||||
EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
|
||||
# fmt: on
|
||||
|
||||
|
||||
# === Individual Dataset Configs ===
|
||||
OXE_DATASET_CONFIGS = {
|
||||
"fractal20220817_data": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3,
|
||||
"robot_type": "Google Robot",
|
||||
},
|
||||
"kuka": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": [
|
||||
"clip_function_input/base_pose_tool_reached",
|
||||
"gripper_closed",
|
||||
],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Kuka iiwa",
|
||||
},
|
||||
"bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture
|
||||
"image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "WidowX",
|
||||
},
|
||||
"bridge_orig": { # Original version of Bridge V2 from project website
|
||||
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "WidowX",
|
||||
},
|
||||
"bridge_dataset": { # Original version of Bridge V2 from project website
|
||||
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "WidowX",
|
||||
},
|
||||
"taco_play": {
|
||||
"image_obs_keys": {
|
||||
"primary": "rgb_static",
|
||||
"secondary": None,
|
||||
"wrist": "rgb_gripper",
|
||||
},
|
||||
"depth_obs_keys": {
|
||||
"primary": "depth_static",
|
||||
"secondary": None,
|
||||
"wrist": "depth_gripper",
|
||||
},
|
||||
"state_obs_keys": ["state_eef", None, "state_gripper"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 15,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"jaco_play": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "image_wrist",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state_eef", None, "state_gripper"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Jaco 2",
|
||||
},
|
||||
"berkeley_cable_routing": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": "top_image",
|
||||
"wrist": "wrist45_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["robot_state", None],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"roboturk": {
|
||||
"image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": [None, None, None, None, None, None, None, None],
|
||||
"state_encoding": StateEncoding.NONE,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Sawyer",
|
||||
},
|
||||
"nyu_door_opening_surprising_effectiveness": {
|
||||
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": [None, None, None, None, None, None, None, None],
|
||||
"state_encoding": StateEncoding.NONE,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3,
|
||||
"robot_type": "Hello Stretch",
|
||||
},
|
||||
"viola": {
|
||||
"image_obs_keys": {
|
||||
"primary": "agentview_rgb",
|
||||
"secondary": None,
|
||||
"wrist": "eye_in_hand_rgb",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["joint_states", "gripper_states"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"berkeley_autolab_ur5": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "hand_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "UR5",
|
||||
},
|
||||
"toto": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 30,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"language_table": {
|
||||
"image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["effector_translation", None, None, None, None, None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "xArm",
|
||||
},
|
||||
"columbia_cairlab_pusht_real": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["robot_state", None, None, None, None, None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "UR5",
|
||||
},
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["ee_position", "ee_orientation", None],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Kuka iiwa",
|
||||
},
|
||||
"nyu_rot_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3,
|
||||
"robot_type": "xArm",
|
||||
},
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"austin_buds_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"nyu_franka_play_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": "image_additional_view",
|
||||
"wrist": None,
|
||||
},
|
||||
"depth_obs_keys": {
|
||||
"primary": "depth",
|
||||
"secondary": "depth_additional_view",
|
||||
"wrist": None,
|
||||
},
|
||||
"state_obs_keys": ["eef_state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"maniskill_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {
|
||||
"primary": "depth",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_depth",
|
||||
},
|
||||
"state_obs_keys": ["tcp_pose", "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"furniture_bench_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"cmu_franka_exploration_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "highres_image",
|
||||
"secondary": None,
|
||||
"wrist": None,
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": [None, None, None, None, None, None, None, None],
|
||||
"state_encoding": StateEncoding.NONE,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["joint_state", None],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 2,
|
||||
"robot_type": "xArm",
|
||||
},
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3,
|
||||
"robot_type": "xArm",
|
||||
},
|
||||
"austin_sailor_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"austin_sirius_dataset_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"bc_z": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": [
|
||||
"present/xyz",
|
||||
"present/axis_angle",
|
||||
None,
|
||||
"present/sensed_close",
|
||||
],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Google Robot",
|
||||
},
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "PR2",
|
||||
},
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "PR2",
|
||||
},
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": "image2",
|
||||
"wrist": "hand_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["end_effector_pose", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "xArm",
|
||||
},
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["pose_r", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "xArm Bimanual",
|
||||
},
|
||||
"robo_net": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 1,
|
||||
"robot_type": "Multi-Robot",
|
||||
},
|
||||
"berkeley_mvp_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["pose", "gripper"],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.JOINT_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "xArm",
|
||||
},
|
||||
"berkeley_rpt_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["joint_pos", "gripper"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.JOINT_POS,
|
||||
"control_frequency": 30,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"kaist_nonprehensile_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None],
|
||||
"state_encoding": StateEncoding.POS_QUAT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"stanford_mask_vit_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": None,
|
||||
"robot_type": "Sawyer",
|
||||
},
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Cobotta",
|
||||
},
|
||||
"dlr_sara_pour_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "DLR SARA",
|
||||
},
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "DLR SARA",
|
||||
},
|
||||
"dlr_edan_shared_control_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "DLR EDAN",
|
||||
},
|
||||
"asu_table_top_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 12.5,
|
||||
"robot_type": "UR5",
|
||||
},
|
||||
"stanford_robocook_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
|
||||
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"imperialcollege_sawyer_wrist_cam": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": [None, None, None, None, None, None, None, "state"],
|
||||
"state_encoding": StateEncoding.NONE,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Sawyer",
|
||||
},
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["joint_state", "gripper_state"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"uiuc_d3field": {
|
||||
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
|
||||
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
|
||||
"state_obs_keys": [None, None, None, None, None, None, None, None],
|
||||
"state_encoding": StateEncoding.NONE,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 1,
|
||||
"robot_type": "Kinova Gen3",
|
||||
},
|
||||
"utaustin_mutex": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["joint_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Fanuc Mate",
|
||||
},
|
||||
"cmu_playing_with_food": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image",
|
||||
"secondary": None,
|
||||
"wrist": "finger_vision_1",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"cmu_play_fusion": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"cmu_stretch": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["eef_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Hello Stretch",
|
||||
},
|
||||
"berkeley_gnm_recon": {
|
||||
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3,
|
||||
"robot_type": "Jackal",
|
||||
},
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "RC Car",
|
||||
},
|
||||
"berkeley_gnm_sac_son": {
|
||||
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["state", None, None],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "TurtleBot 2",
|
||||
},
|
||||
# NOTE: modified
|
||||
"droid": {
|
||||
"image_obs_keys": {
|
||||
"primary": "exterior_image_1_left",
|
||||
"secondary": "exterior_image_2_left",
|
||||
"wrist": "wrist_image_left",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 15,
|
||||
"robot_type": "Franka",
|
||||
"aux_kwargs": {
|
||||
"dataset_frame_transform_kwargs": {
|
||||
"chunk_filter_fn": zero_action_filter,
|
||||
},
|
||||
},
|
||||
},
|
||||
"fmb_dataset": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image_side_1",
|
||||
"secondary": "image_side_2",
|
||||
"wrist": "image_wrist_1",
|
||||
},
|
||||
"depth_obs_keys": {
|
||||
"primary": "image_side_1_depth",
|
||||
"secondary": "image_side_2_depth",
|
||||
"wrist": "image_wrist_1_depth",
|
||||
},
|
||||
"state_obs_keys": ["proprio"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
# NOTE: modified
|
||||
"dobbe": {
|
||||
"image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 3.75,
|
||||
"robot_type": "Hello Stretch",
|
||||
},
|
||||
"roboset": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image_left",
|
||||
"secondary": "image_right",
|
||||
"wrist": "image_wrist",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["proprio"],
|
||||
"state_encoding": StateEncoding.JOINT,
|
||||
"action_encoding": ActionEncoding.JOINT_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"rh20t": {
|
||||
"image_obs_keys": {
|
||||
"primary": "image_front",
|
||||
"secondary": "image_side_right",
|
||||
"wrist": "image_wrist",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["proprio"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 10,
|
||||
"robot_type": "Flexiv",
|
||||
},
|
||||
### T-DROID datasets
|
||||
"tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control
|
||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"tdroid_pour_corn_in_pot": { # "pour corn from red bonawl into steel pot" task, 50 demos @ 5 Hz control
|
||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control
|
||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"tdroid_move_object_onto_plate": { # "move <object> onto plate" task, 150 demos @ 5 Hz control
|
||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"tdroid_knock_object_over": { # "knock <object> over" task, 70 demos @ 5 Hz control
|
||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"tdroid_cover_object_with_towel": { # "cover <object> with towel" task, 45 demos @ 5 Hz control
|
||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 5,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
### DROID Finetuning datasets
|
||||
"droid_wipe": {
|
||||
"image_obs_keys": {
|
||||
"primary": "exterior_image_2_left",
|
||||
"secondary": None,
|
||||
"wrist": "wrist_image_left",
|
||||
},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["proprio"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 15,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
# NOTE: modified
|
||||
### LIBERO datasets (modified versions)
|
||||
"libero_spatial_no_noops": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"libero_object_no_noops": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"libero_goal_no_noops": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
"libero_10_no_noops": {
|
||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
|
||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||
"state_obs_keys": ["EEF_state", "gripper_state"],
|
||||
"state_encoding": StateEncoding.POS_EULER,
|
||||
"action_encoding": ActionEncoding.EEF_POS,
|
||||
"control_frequency": 20,
|
||||
"robot_type": "Franka",
|
||||
},
|
||||
}
|
||||
76
examples/port_datasets/oxe_utils/transform_utils.py
Normal file
76
examples/port_datasets/oxe_utils/transform_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Converts gripper actions from continuous to binary values (0 and 1).
|
||||
|
||||
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
|
||||
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
|
||||
values based on the state that is reached _after_ those intermediate values.
|
||||
|
||||
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
|
||||
chunk of intermediate values as the last action in the trajectory.
|
||||
|
||||
The `scan_fn` implements the following logic:
|
||||
new_actions = np.empty_like(actions)
|
||||
carry = actions[-1]
|
||||
for i in reversed(range(actions.shape[0])):
|
||||
if in_between_mask[i]:
|
||||
carry = carry
|
||||
else:
|
||||
carry = float(open_mask[i])
|
||||
new_actions[i] = carry
|
||||
"""
|
||||
open_mask, closed_mask = actions > 0.95, actions < 0.05
|
||||
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
|
||||
is_open_float = tf.cast(open_mask, tf.float32)
|
||||
|
||||
def scan_fn(carry, i):
|
||||
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
|
||||
|
||||
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
|
||||
|
||||
|
||||
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
return 1 - actions
|
||||
|
||||
|
||||
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
|
||||
|
||||
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
|
||||
"""
|
||||
# Note =>> -1 for closing, 1 for opening, 0 for no change
|
||||
opening_mask, closing_mask = actions < -0.1, actions > 0.1
|
||||
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
|
||||
|
||||
def scan_fn(carry, i):
|
||||
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
|
||||
|
||||
# If no relative grasp, assumes open for whole trajectory
|
||||
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
|
||||
start = tf.cond(start == 0, lambda: 1, lambda: start)
|
||||
|
||||
# Note =>> -1 for closed, 1 for open
|
||||
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
|
||||
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
|
||||
|
||||
return new_actions
|
||||
|
||||
|
||||
# === Bridge-V2 =>> Dataset-Specific Transform ===
|
||||
def relabel_bridge_actions(traj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
|
||||
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
|
||||
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
|
||||
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
|
||||
|
||||
return traj_truncated
|
||||
1006
examples/port_datasets/oxe_utils/transforms.py
Normal file
1006
examples/port_datasets/oxe_utils/transforms.py
Normal file
File diff suppressed because it is too large
Load Diff
359
examples/port_datasets/port_rlds.py
Normal file
359
examples/port_datasets/port_rlds.py
Normal file
@@ -0,0 +1,359 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
|
||||
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||
|
||||
# Default FPS for datasets without specific config
|
||||
DEFAULT_FPS = 10
|
||||
DEFAULT_ROBOT_TYPE = "unknown"
|
||||
|
||||
|
||||
def determine_dataset_info(raw_dir: Path):
|
||||
"""Determine dataset name and version from directory structure."""
|
||||
last_part = raw_dir.name
|
||||
if re.match(r"^\d+\.\d+\.\d+$", last_part):
|
||||
version = last_part
|
||||
dataset_name = raw_dir.parent.name
|
||||
data_dir = raw_dir.parent.parent
|
||||
else:
|
||||
version = ""
|
||||
dataset_name = last_part
|
||||
data_dir = raw_dir.parent
|
||||
return dataset_name, version, data_dir
|
||||
|
||||
|
||||
def generate_features_from_builder(builder, dataset_name: str) -> dict[str, Any]:
|
||||
"""Generate LeRobot features schema from TFDS builder and dataset config."""
|
||||
|
||||
# Generate state names based on encoding type
|
||||
state_names = [f"motor_{i}" for i in range(8)]
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
|
||||
if state_encoding == StateEncoding.POS_EULER:
|
||||
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
|
||||
if "libero" in dataset_name:
|
||||
state_names = [
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"roll",
|
||||
"pitch",
|
||||
"yaw",
|
||||
"gripper",
|
||||
"gripper",
|
||||
] # 2D gripper state
|
||||
elif state_encoding == StateEncoding.POS_QUAT:
|
||||
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
|
||||
elif state_encoding == StateEncoding.JOINT:
|
||||
state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
|
||||
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
|
||||
pad_count = state_obs_keys[:-1].count(None)
|
||||
state_names[-pad_count - 1 : -1] = ["pad"] * pad_count
|
||||
state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1]
|
||||
|
||||
# Generate action names based on encoding type
|
||||
action_names = [f"motor_{i}" for i in range(8)]
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
|
||||
if action_encoding == ActionEncoding.EEF_POS:
|
||||
action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
|
||||
elif action_encoding == ActionEncoding.JOINT_POS:
|
||||
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
|
||||
|
||||
# Base features (state and action)
|
||||
features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state_names),),
|
||||
"names": {"axes": state_names},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action_names),),
|
||||
"names": {"axes": action_names},
|
||||
},
|
||||
}
|
||||
|
||||
# Add image features from TFDS builder info
|
||||
obs_features = builder.info.features["steps"]["observation"]
|
||||
for key, value in obs_features.items():
|
||||
# Skip depth images and non-image features
|
||||
if "depth" in key or not any(x in key for x in ["image", "rgb"]):
|
||||
continue
|
||||
|
||||
features[f"observation.images.{key}"] = {
|
||||
"dtype": "video",
|
||||
"shape": tuple(value.shape),
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def transform_raw_dataset(episode, dataset_name: str):
|
||||
"""Apply OXE standardization transforms to raw TFDS episode."""
|
||||
# Batch all steps in the episode
|
||||
traj = next(iter(episode["steps"].batch(episode["steps"].cardinality())))
|
||||
|
||||
# Apply dataset-specific transform if available
|
||||
if dataset_name in OXE_STANDARDIZATION_TRANSFORMS:
|
||||
traj = OXE_STANDARDIZATION_TRANSFORMS[dataset_name](traj)
|
||||
|
||||
# Create consolidated state vector
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
|
||||
else:
|
||||
state_obs_keys = [None for _ in range(8)]
|
||||
|
||||
# Build proprio (proprioceptive state) vector
|
||||
proprio_components = []
|
||||
for key in state_obs_keys:
|
||||
if key is None:
|
||||
# Add padding for missing state components
|
||||
component = tf.zeros((tf.shape(traj["action"])[0], 1), dtype=tf.float32)
|
||||
else:
|
||||
component = tf.cast(traj["observation"][key], tf.float32)
|
||||
# Ensure component has right shape (add dimension if needed)
|
||||
if len(component.shape) == 1:
|
||||
component = component[:, None]
|
||||
proprio_components.append(component)
|
||||
|
||||
proprio = tf.concat(proprio_components, axis=1)
|
||||
|
||||
# Update trajectory with standardized format
|
||||
traj.update(
|
||||
{
|
||||
"proprio": proprio,
|
||||
"task": traj.get("language_instruction", ""),
|
||||
"action": tf.cast(traj["action"], tf.float32),
|
||||
}
|
||||
)
|
||||
|
||||
episode["steps"] = traj
|
||||
return episode
|
||||
|
||||
|
||||
def generate_lerobot_frames(tf_episode):
|
||||
"""Generate LeRobot frames from transformed TFDS episode."""
|
||||
traj = tf_episode["steps"]
|
||||
|
||||
# Get the task/language instruction
|
||||
if isinstance(traj["task"], tf.Tensor):
|
||||
if traj["task"].dtype == tf.string:
|
||||
task = traj["task"][0].numpy().decode() if len(traj["task"]) > 0 else ""
|
||||
else:
|
||||
task = str(traj["task"][0].numpy()) if len(traj["task"]) > 0 else ""
|
||||
else:
|
||||
task = str(traj["task"]) if traj["task"] else ""
|
||||
|
||||
# Iterate through each timestep
|
||||
num_steps = tf.shape(traj["action"])[0].numpy()
|
||||
for i in range(num_steps):
|
||||
frame = {}
|
||||
|
||||
# Add observation state
|
||||
frame["observation.state"] = traj["proprio"][i].numpy()
|
||||
|
||||
# Add action
|
||||
frame["action"] = traj["action"][i].numpy()
|
||||
|
||||
# Add images
|
||||
for key, value in traj["observation"].items():
|
||||
if any(x in key for x in ["image", "rgb"]) and "depth" not in key:
|
||||
frame[f"observation.images.{key}"] = value[i].numpy()
|
||||
|
||||
# Add task
|
||||
frame["task"] = task
|
||||
|
||||
# Cast fp64 to fp32
|
||||
for key in frame:
|
||||
if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
|
||||
frame[key] = frame[key].astype(np.float32)
|
||||
|
||||
yield frame
|
||||
|
||||
|
||||
def port_rlds(
|
||||
raw_dir: Path,
|
||||
repo_id: str,
|
||||
push_to_hub: bool = False,
|
||||
num_shards: int | None = None,
|
||||
shard_index: int | None = None,
|
||||
):
|
||||
"""Port RLDS dataset to LeRobot format."""
|
||||
|
||||
# Determine dataset info
|
||||
dataset_name, version, data_dir = determine_dataset_info(raw_dir)
|
||||
|
||||
# Build TFDS dataset
|
||||
builder = tfds.builder(
|
||||
f"{dataset_name}/{version}" if version else dataset_name, data_dir=data_dir, version=version
|
||||
)
|
||||
|
||||
# Handle sharding if specified
|
||||
if num_shards is not None and shard_index is not None:
|
||||
if shard_index >= num_shards:
|
||||
raise ValueError(f"Shard index {shard_index} >= num_shards {num_shards}")
|
||||
|
||||
# Calculate shard splits
|
||||
total_episodes = builder.info.splits["train"].num_examples
|
||||
episodes_per_shard = total_episodes // num_shards
|
||||
start_idx = shard_index * episodes_per_shard
|
||||
if shard_index == num_shards - 1:
|
||||
# Last shard gets remaining episodes
|
||||
end_idx = total_episodes
|
||||
else:
|
||||
end_idx = start_idx + episodes_per_shard
|
||||
|
||||
split_str = f"train[{start_idx}:{end_idx}]"
|
||||
raw_dataset = builder.as_dataset(split=split_str)
|
||||
else:
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
# Apply filtering (e.g., success filter for kuka)
|
||||
if dataset_name == "kuka":
|
||||
raw_dataset = raw_dataset.filter(lambda e: e["success"])
|
||||
|
||||
# Apply transformations
|
||||
raw_dataset = raw_dataset.map(partial(transform_raw_dataset, dataset_name=dataset_name))
|
||||
|
||||
# Get dataset configuration
|
||||
fps = DEFAULT_FPS
|
||||
robot_type = DEFAULT_ROBOT_TYPE
|
||||
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
config = OXE_DATASET_CONFIGS[dataset_name]
|
||||
fps = config.get("control_frequency", DEFAULT_FPS)
|
||||
robot_type = config.get("robot_type", DEFAULT_ROBOT_TYPE)
|
||||
robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")
|
||||
|
||||
# Generate features schema
|
||||
features = generate_features_from_builder(builder, dataset_name)
|
||||
|
||||
# Create LeRobot dataset
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
robot_type=robot_type,
|
||||
fps=int(fps),
|
||||
features=features,
|
||||
)
|
||||
|
||||
# Process episodes
|
||||
start_time = time.time()
|
||||
num_episodes = raw_dataset.cardinality().numpy().item()
|
||||
logging.info(f"Number of episodes: {num_episodes}")
|
||||
|
||||
for episode_index, episode in enumerate(raw_dataset):
|
||||
elapsed_time = time.time() - start_time
|
||||
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
|
||||
|
||||
logging.info(
|
||||
f"{episode_index} / {num_episodes} episodes processed "
|
||||
f"(after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
|
||||
)
|
||||
|
||||
# Generate and add frames
|
||||
for frame in generate_lerobot_frames(episode):
|
||||
lerobot_dataset.add_frame(frame)
|
||||
|
||||
lerobot_dataset.save_episode()
|
||||
logging.info("Save_episode")
|
||||
|
||||
# Push to hub if requested
|
||||
if push_to_hub:
|
||||
tags = ["openx", dataset_name]
|
||||
if robot_type != "unknown":
|
||||
tags.append(robot_type)
|
||||
|
||||
lerobot_dataset.push_to_hub(
|
||||
tags=tags,
|
||||
private=False,
|
||||
)
|
||||
|
||||
|
||||
def validate_dataset(repo_id):
|
||||
"""Sanity check that ensures metadata can be loaded and all files are present."""
|
||||
meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
if meta.total_episodes == 0:
|
||||
raise ValueError("Number of episodes is 0.")
|
||||
|
||||
for ep_idx in range(meta.total_episodes):
|
||||
data_path = meta.root / meta.get_data_file_path(ep_idx)
|
||||
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Parquet file is missing in: {data_path}")
|
||||
|
||||
for vid_key in meta.video_keys:
|
||||
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not vid_path.exists():
|
||||
raise ValueError(f"Video file is missing in: {vid_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload to hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of shards to split the dataset into for parallel processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard-index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Index of the shard to process (0-indexed).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
port_rlds(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -20,7 +20,7 @@ from pathlib import Path
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
|
||||
class PortDroidShards(PipelineStep):
|
||||
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
|
||||
from port_droid import port_droid, validate_dataset
|
||||
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -61,13 +61,71 @@ class PortDroidShards(PipelineStep):
|
||||
validate_dataset(shard_repo_id)
|
||||
|
||||
|
||||
class PortRLDSShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path | str,
|
||||
repo_id: str = None,
|
||||
num_shards: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = Path(raw_dir)
|
||||
self.repo_id = repo_id
|
||||
self.num_shards = num_shards
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from port_rlds import port_rlds, validate_dataset
|
||||
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
|
||||
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
|
||||
|
||||
try:
|
||||
validate_dataset(shard_repo_id)
|
||||
return
|
||||
except Exception:
|
||||
pass # nosec B110 - Dataset doesn't exist yet, continue with porting
|
||||
|
||||
port_rlds(
|
||||
self.raw_dir,
|
||||
shard_repo_id,
|
||||
push_to_hub=False,
|
||||
num_shards=world_size,
|
||||
shard_index=rank,
|
||||
)
|
||||
|
||||
validate_dataset(shard_repo_id)
|
||||
|
||||
|
||||
def make_port_executor(
|
||||
raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
raw_dir,
|
||||
repo_id,
|
||||
job_name,
|
||||
logs_dir,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
slurm=True,
|
||||
dataset_type="droid",
|
||||
num_shards=None,
|
||||
):
|
||||
# Select appropriate pipeline step based on dataset type
|
||||
if dataset_type.lower() == "droid":
|
||||
pipeline_step = PortDroidShards(raw_dir, repo_id)
|
||||
default_shards = DROID_SHARDS
|
||||
elif dataset_type.lower() == "rlds":
|
||||
pipeline_step = PortRLDSShards(raw_dir, repo_id, num_shards)
|
||||
default_shards = num_shards or workers # Use num_shards or fallback to workers
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
||||
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
PortDroidShards(raw_dir, repo_id),
|
||||
],
|
||||
"pipeline": [pipeline_step],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
@@ -75,7 +133,7 @@ def make_port_executor(
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": DROID_SHARDS,
|
||||
"tasks": default_shards,
|
||||
"workers": workers,
|
||||
"time": "08:00:00",
|
||||
"partition": partition,
|
||||
@@ -113,13 +171,21 @@ def main():
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
help="Path to logs directory for `datatrove`.",
|
||||
default=Path("./logs"),
|
||||
help="Path to logs directory for `datatrove` (default: ./logs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-type",
|
||||
type=str,
|
||||
choices=["droid", "rlds"],
|
||||
default="droid",
|
||||
help="Type of dataset to process: 'droid' for DROID datasets or 'rlds' for RLDS/OpenX datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="port_droid",
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
|
||||
default=None,
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory. Defaults to 'port_{dataset_type}'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
@@ -130,8 +196,14 @@ def main():
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of slurm workers. It should be less than the maximum number of shards.",
|
||||
default=None,
|
||||
help="Number of slurm workers. Defaults: 2048 for DROID, 64 for RLDS datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of shards to split the dataset into. For DROID datasets, this is fixed at 2048. For RLDS datasets, defaults to number of workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
@@ -152,8 +224,21 @@ def main():
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set defaults based on dataset type
|
||||
if args.job_name is None:
|
||||
args.job_name = f"port_{args.dataset_type}"
|
||||
|
||||
if args.workers is None:
|
||||
if args.dataset_type == "droid":
|
||||
args.workers = 2048
|
||||
else: # rlds
|
||||
args.workers = 64
|
||||
|
||||
# Convert args to kwargs and process
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
port_executor = make_port_executor(**kwargs)
|
||||
port_executor.run()
|
||||
|
||||
|
||||
@@ -105,7 +105,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
@@ -141,7 +140,6 @@ all = [
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
class Reachy2CameraConfig(CameraConfig):
|
||||
"""Configuration class for Reachy 2 camera devices.
|
||||
|
||||
This class provides configuration options for Reachy 2 cameras,
|
||||
supporting both the teleop and depth cameras. It includes settings
|
||||
for resolution, frame rate, color mode, and the selection of the cameras.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address="192.168.0.200", # IP address of the robot
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
) # Left teleop camera, 640x480 @ 15FPS
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: Name of the camera device. Can be "teleop" or "depth".
|
||||
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
|
||||
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
|
||||
fps: Requested frames per second for the color stream.
|
||||
width: Requested frame width in pixels for the color stream.
|
||||
height: Requested frame height in pixels for the color stream.
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
ip_address: IP address of the robot. Defaults to "localhost".
|
||||
port: Port number for the camera server. Defaults to 50065.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
"""
|
||||
|
||||
name: str
|
||||
image_type: str
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
ip_address: str | None = "localhost"
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
self.name == "depth" and self.image_type not in ["rgb", "depth"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
@@ -1,288 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Reachy2Camera(Camera):
|
||||
"""
|
||||
Manages Reachy 2 camera using Reachy 2 CameraManager.
|
||||
|
||||
This class provides a high-level interface to connect to, configure, and read
|
||||
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
|
||||
frame reading.
|
||||
|
||||
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
|
||||
type (e.g., "left") to be specified in the configuration.
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
overridden in the configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Reachy2CameraConfig):
|
||||
"""
|
||||
Initializes the Reachy2Camera instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
self.fps = config.fps
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
|
||||
self.cam_manager.initialize_cameras()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detects available Reachy 2 cameras.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains 'name', 'stereo',
|
||||
and the default profile properties (width, height, fps).
|
||||
"""
|
||||
initialized_cameras = []
|
||||
camera_manager = CameraManager(host=ip_address, port=port)
|
||||
|
||||
for camera in [camera_manager.teleop, camera_manager.depth]:
|
||||
if camera is None:
|
||||
continue
|
||||
|
||||
height, width, _, _, _, _, _ = camera.get_parameters()
|
||||
|
||||
camera_info = {
|
||||
"name": camera._cam_info.name,
|
||||
"stereo": camera._cam_info.stereo,
|
||||
"default_profile": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"fps": 30,
|
||||
},
|
||||
}
|
||||
initialized_cameras.append(camera_info)
|
||||
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
color mode and applying any configured rotation.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
else:
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.config.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
|
||||
elif self.config.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
|
||||
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.config.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()[0]
|
||||
elif self.config.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
(height, width, channels), processed according to configuration.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -37,14 +37,8 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
from .realsense.camera_realsense import RealSenseCamera
|
||||
|
||||
cameras[key] = RealSenseCamera(cfg)
|
||||
|
||||
elif cfg.type == "reachy2_camera":
|
||||
from .reachy2_camera.reachy2_camera import Reachy2Camera
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -52,8 +52,3 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu
|
||||
# calibration dir
|
||||
default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||
|
||||
|
||||
# streaming datasets
|
||||
LOOKBACK_BACKTRACKTABLE = 100
|
||||
LOOKAHEAD_BACKTRACKTABLE = 100
|
||||
|
||||
@@ -39,7 +39,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files
|
||||
from lerobot.datasets.video_utils import concat_video_files
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -298,9 +298,12 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
|
||||
|
||||
# Append to existing video file
|
||||
concatenate_video_files(
|
||||
concat_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
dst_meta.root,
|
||||
key,
|
||||
chunk_idx,
|
||||
file_idx,
|
||||
)
|
||||
# Update the latest_duration when appending (shifts timestamps!)
|
||||
update_latest_duration = not update_latest_duration
|
||||
|
||||
@@ -14,11 +14,43 @@
|
||||
|
||||
import packaging.version
|
||||
|
||||
V30_MESSAGE = """
|
||||
V2_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v3.0 which is not backward compatible with v2.1.
|
||||
Please, update your dataset to the new format using this command:
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||
sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V21_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V30_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
@@ -25,7 +25,6 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
|
||||
IMAGENET_STATS = {
|
||||
@@ -88,26 +87,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
|
||||
@@ -29,6 +29,7 @@ import PIL.Image
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
@@ -72,7 +73,7 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
concat_video_files,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_safe_default_codec,
|
||||
@@ -128,10 +129,6 @@ class LeRobotDatasetMetadata:
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
@@ -349,26 +346,21 @@ class LeRobotDatasetMetadata:
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
for key in self.video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(
|
||||
video_key=video_key, chunk_index=0, file_index=0
|
||||
)
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
@@ -473,7 +465,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -584,8 +575,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -597,8 +586,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -674,10 +661,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if tag_version:
|
||||
with contextlib.suppress(RevisionNotFoundError):
|
||||
@@ -969,10 +957,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
|
||||
Video encoding is handled automatically based on batch_encoding_size:
|
||||
- If batch_encoding_size == 1: Videos are encoded immediately after each episode
|
||||
- If batch_encoding_size > 1: Videos are encoded in batches.
|
||||
|
||||
Args:
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
@@ -1009,81 +993,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
# Check if we should trigger batch encoding
|
||||
self.episodes_since_last_encoding += 1
|
||||
if self.episodes_since_last_encoding == self.batch_encoding_size:
|
||||
start_ep = self.num_episodes - self.batch_encoding_size
|
||||
end_ep = self.num_episodes
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
self.episodes_since_last_encoding = 0
|
||||
|
||||
if not episode_data:
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None):
|
||||
"""
|
||||
Batch save videos for multiple episodes.
|
||||
|
||||
Args:
|
||||
start_episode: Starting episode index (inclusive)
|
||||
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
|
||||
"""
|
||||
if end_episode is None:
|
||||
end_episode = self.num_episodes
|
||||
|
||||
logging.info(
|
||||
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"]
|
||||
file_idx = self.meta.episodes[start_episode]["data/file_index"]
|
||||
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logging.info(f"Encoding videos for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
or self.meta.episodes[ep_idx]["data/file_index"] != file_idx
|
||||
):
|
||||
# The current episode is in a new chunk or file.
|
||||
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
# Load new episode dataframe
|
||||
chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
|
||||
file_idx = self.meta.episodes[ep_idx]["data/file_index"]
|
||||
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
# Save the current episode's video metadata to the dataframe
|
||||
video_ep_metadata = {}
|
||||
for video_key in self.meta.video_keys:
|
||||
video_ep_metadata.update(self._save_episode_video(video_key, ep_idx))
|
||||
video_ep_metadata.pop("episode_index")
|
||||
video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
episode_df = episode_df.combine_first(video_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
# Reset episode buffer and clean up temporary images
|
||||
self.clear_episode_buffer()
|
||||
|
||||
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
||||
"""Save episode data to a parquet file and update the Hugging Face dataset of frames data.
|
||||
@@ -1164,10 +1082,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
if self.meta.episodes is None or (
|
||||
f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names
|
||||
or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names
|
||||
):
|
||||
if self.meta.episodes is None:
|
||||
# Initialize indices for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
latest_duration_in_s = 0.0
|
||||
@@ -1177,8 +1092,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Retrieve information from the latest updated video file (possibly several episodes ago)
|
||||
latest_ep = self.meta.episodes[episode_index - 1]
|
||||
# Retrieve information from the latest video file
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"]
|
||||
|
||||
@@ -1199,19 +1114,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
# Update latest video file
|
||||
concatenate_video_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
|
||||
|
||||
# Remove temporary directory
|
||||
shutil.rmtree(str(ep_path.parent))
|
||||
|
||||
# Update video info (only needed when first episode is encoded since it reads from episode 0)
|
||||
if episode_index == 0:
|
||||
self.meta.update_video_info(video_key)
|
||||
write_info(self.meta.info, self.meta.root) # ensure video info always written properly
|
||||
|
||||
metadata = {
|
||||
"episode_index": episode_index,
|
||||
f"videos/{video_key}/chunk_index": chunk_idx,
|
||||
@@ -1221,17 +1128,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
if self.image_writer is not None:
|
||||
self._wait_image_writer()
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
def clear_episode_buffer(self) -> None:
|
||||
if self.image_writer is not None:
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
img_dir = self.root / "images" / cam_key
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
@@ -1272,7 +1172,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
@@ -1288,7 +1187,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
obj = cls.__new__(cls)
|
||||
@@ -1305,8 +1203,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
@@ -1,535 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
LookAheadError,
|
||||
LookBackError,
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
get_delta_indices,
|
||||
is_float_in_list,
|
||||
item_to_torch,
|
||||
safe_shard,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
|
||||
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
||||
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
||||
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
||||
|
||||
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
||||
items, allowing us to access previous frames for delta timestamps without loading the entire
|
||||
dataset into memory.
|
||||
|
||||
Example:
|
||||
Basic usage:
|
||||
```python
|
||||
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
# Create a streaming dataset with delta timestamps
|
||||
delta_timestamps = {
|
||||
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
||||
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
||||
}
|
||||
|
||||
dataset = StreamingLeRobotDataset(
|
||||
repo_id="your-dataset-repo-id",
|
||||
delta_timestamps=delta_timestamps,
|
||||
streaming=True,
|
||||
buffer_size=1000,
|
||||
)
|
||||
|
||||
# Iterate over the dataset
|
||||
for i, item in enumerate(dataset):
|
||||
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
||||
# item will contain stacked frames according to delta_timestamps
|
||||
if i >= 10:
|
||||
break
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
buffer_size: int = 1000,
|
||||
max_num_shards: int = 16,
|
||||
seed: int = 42,
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list.
|
||||
image_transforms (Callable | None, optional): Transform to apply to image data.
|
||||
tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
|
||||
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
||||
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.streaming_from_local = root is not None
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.seed = seed
|
||||
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
self.delta_timestamps = None
|
||||
self.delta_indices = None
|
||||
|
||||
if delta_timestamps is not None:
|
||||
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
return self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return self.meta.fps
|
||||
|
||||
@staticmethod
|
||||
def _iter_random_indices(
|
||||
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
||||
) -> Iterator[int]:
|
||||
while True:
|
||||
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
||||
|
||||
@staticmethod
|
||||
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
# the logic is to add 2 levels of randomness:
|
||||
# (1) sample one shard at random from the ones available, and
|
||||
# (2) sample one frame from the shard sampled at (1)
|
||||
frames_buffer = []
|
||||
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
||||
|
||||
try:
|
||||
for frame in self.make_frame(backtrack_dataset):
|
||||
if len(frames_buffer) == self.buffer_size:
|
||||
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||
yield frames_buffer[i]
|
||||
frames_buffer[i] = frame
|
||||
else:
|
||||
frames_buffer.append(frame)
|
||||
break # random shard sampled, switch shard
|
||||
except (
|
||||
RuntimeError,
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
if delta_timestamps is None:
|
||||
return 1, 1
|
||||
|
||||
if not dynamic_bounds:
|
||||
# Fix the windows
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
|
||||
return lookback, lookahead
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
) -> dict[str, list[float]]:
|
||||
if indices is not None:
|
||||
return {
|
||||
key: (
|
||||
start_ts + torch.tensor(indices[key]) / self.fps
|
||||
).tolist() # NOTE: why not delta_timestamps directly?
|
||||
for key in self.delta_timestamps
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.meta.video_keys, [start_ts])
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
video_frames: dict[str, torch.Tensor],
|
||||
query_timestamps: dict[str, list[float]],
|
||||
original_timestamps: dict[str, list[float]],
|
||||
) -> dict[str, torch.BoolTensor]:
|
||||
padding_mask = {}
|
||||
|
||||
for video_key, timestamps in original_timestamps.items():
|
||||
if video_key not in video_frames:
|
||||
continue # only padding on video keys that are available
|
||||
frames = []
|
||||
mask = []
|
||||
padding_frame = self._make_padding_camera_frame(video_key)
|
||||
for ts in timestamps:
|
||||
if is_float_in_list(ts, query_timestamps[video_key]):
|
||||
idx = find_float_index(ts, query_timestamps[video_key])
|
||||
frames.append(video_frames[video_key][idx, :])
|
||||
mask.append(False)
|
||||
else:
|
||||
frames.append(padding_frame)
|
||||
mask.append(True)
|
||||
|
||||
padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(
|
||||
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
|
||||
) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
|
||||
# Apply delta querying logic if necessary
|
||||
if self.delta_indices is not None:
|
||||
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
# Load video frames, when needed
|
||||
if len(self.meta.video_keys) > 0:
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
|
||||
# Some timestamps might not result available considering the episode's boundaries
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
updates.append(padding_mask)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
|
||||
yield result
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = keys_to_timestamps[key]
|
||||
# Clamp out timesteps outside of episode boundaries
|
||||
query_timestamps[key] = torch.clamp(
|
||||
torch.tensor(timestamps), *episode_boundaries_ts[key]
|
||||
).tolist()
|
||||
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
return query_timestamps
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||
)
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
return item
|
||||
|
||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||
# TODO(fracapuano): Modularize this function, refactor the code
|
||||
"""Get frames with delta offsets using the backtrackable iterator.
|
||||
|
||||
Args:
|
||||
current_item (dict): Current item from the iterator.
|
||||
ep_idx (int): Episode index.
|
||||
|
||||
Returns:
|
||||
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||
"""
|
||||
current_episode_idx = current_item["episode_index"]
|
||||
|
||||
# Prepare results
|
||||
query_result = {}
|
||||
padding = {}
|
||||
|
||||
for key, delta_indices in self.delta_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue # visual frames are decoded separately
|
||||
|
||||
target_frames = []
|
||||
is_pad = []
|
||||
|
||||
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||
delta_results = {}
|
||||
|
||||
# Separate and sort deltas by difficulty (easier operations first)
|
||||
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||
zero_deltas = [d for d in delta_indices if d == 0]
|
||||
|
||||
# Process zero deltas (current frame)
|
||||
for delta in zero_deltas:
|
||||
delta_results[delta] = (
|
||||
current_item[key],
|
||||
False,
|
||||
)
|
||||
|
||||
# Process negative deltas in order of increasing difficulty
|
||||
lookback_failed = False
|
||||
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in negative_deltas:
|
||||
if lookback_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (past_item[key], False)
|
||||
last_successful_frame = past_item[key]
|
||||
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookback_failed = True # All subsequent negative deltas will also fail
|
||||
|
||||
# Process positive deltas in order of increasing difficulty
|
||||
lookahead_failed = False
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in positive_deltas:
|
||||
if lookahead_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (future_item[key], False)
|
||||
last_successful_frame = future_item[key]
|
||||
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||
|
||||
# Reconstruct original order for stacking
|
||||
for delta in delta_indices:
|
||||
frame, is_padded = delta_results[delta]
|
||||
|
||||
# add batch dimension for stacking
|
||||
target_frames.append(frame) # frame.unsqueeze(0))
|
||||
is_pad.append(is_padded)
|
||||
|
||||
# Stack frames and add to results
|
||||
if target_frames:
|
||||
query_result[key] = torch.stack(target_frames)
|
||||
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||
|
||||
return query_result, padding
|
||||
|
||||
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
||||
"""
|
||||
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
|
||||
"""
|
||||
if delta_timestamps is None:
|
||||
return
|
||||
|
||||
# Get all available feature keys from the dataset metadata
|
||||
available_features = set(self.meta.features.keys())
|
||||
|
||||
# Get all keys from delta_timestamps
|
||||
delta_keys = set(delta_timestamps.keys())
|
||||
|
||||
# Find any keys that don't correspond to features
|
||||
invalid_keys = delta_keys - available_features
|
||||
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
|
||||
f"Available features are: {sorted(available_features)}"
|
||||
)
|
||||
@@ -17,11 +17,10 @@ import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Deque, Generic, TypeVar
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -87,8 +86,6 @@ DEFAULT_FEATURES = {
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
@@ -779,230 +776,3 @@ def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
|
||||
|
||||
This function is used to convert an item from a streaming dataset to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
item (dict): Dictionary of items from a dataset.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
|
||||
def is_float_in_list(target, float_list, threshold=1e-6):
|
||||
return any(abs(target - x) <= threshold for x in float_list)
|
||||
|
||||
|
||||
def find_float_index(target, float_list, threshold=1e-6):
|
||||
for i, x in enumerate(float_list):
|
||||
if abs(target - x) <= threshold:
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look back in the history of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LookAheadError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look ahead in the future of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable(Generic[T]):
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
This is useful for streaming datasets where you need to access previous and future items
|
||||
but can't load the entire dataset into memory.
|
||||
|
||||
Example:
|
||||
-------
|
||||
```python
|
||||
ds = load_dataset("c4", "en", streaming=True, split="train")
|
||||
rev = Backtrackable(ds, history=3, lookahead=2)
|
||||
|
||||
x0 = next(rev) # forward
|
||||
x1 = next(rev)
|
||||
x2 = next(rev)
|
||||
|
||||
# Look ahead
|
||||
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
||||
x4_peek = rev.peek_ahead(2) # two items ahead
|
||||
|
||||
# Look back
|
||||
x1_again = rev.peek_back(1) # previous item without moving cursor
|
||||
x0_again = rev.peek_back(2) # two items back
|
||||
|
||||
# Move backward
|
||||
x1_back = rev.prev() # back one step
|
||||
next(rev) # returns x2, continues forward from where we were
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
||||
|
||||
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
||||
if history < 1:
|
||||
raise ValueError("history must be >= 1")
|
||||
if lookahead <= 0:
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: Deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
def __iter__(self) -> "Backtrackable[T]":
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
# If we've stepped back, consume from back buffer first
|
||||
if self._cursor < 0: # -1 means "last item", etc.
|
||||
self._cursor += 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
# If we have items in the ahead buffer, use them first
|
||||
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
||||
|
||||
# Add current item to back buffer and reset cursor
|
||||
self._back_buf.append(item)
|
||||
self._cursor = 0
|
||||
return item
|
||||
|
||||
def prev(self) -> T:
|
||||
"""
|
||||
Step one item back in history and return it.
|
||||
Raises IndexError if already at the oldest buffered item.
|
||||
"""
|
||||
if len(self._back_buf) + self._cursor <= 1:
|
||||
raise LookBackError("At start of history")
|
||||
|
||||
self._cursor -= 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
def peek_back(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items back (n=1 == previous item) without moving the cursor.
|
||||
"""
|
||||
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
||||
raise LookBackError("peek_back distance out of range")
|
||||
|
||||
return self._back_buf[self._cursor - (n + 1)]
|
||||
|
||||
def peek_ahead(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
||||
Fills the ahead buffer if necessary.
|
||||
"""
|
||||
if n < 1:
|
||||
raise LookAheadError("peek_ahead distance must be 1 or more")
|
||||
elif n > self._lookahead:
|
||||
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
||||
|
||||
# Fill ahead buffer if we don't have enough items
|
||||
while len(self._ahead_buf) < n:
|
||||
try:
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
|
||||
except StopIteration as err:
|
||||
raise LookAheadError("peek_ahead: not enough items in source") from err
|
||||
|
||||
return self._ahead_buf[n - 1]
|
||||
|
||||
def history(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the buffered history (most recent last).
|
||||
The list length ≤ `history` argument passed at construction.
|
||||
"""
|
||||
if self._cursor == 0:
|
||||
return list(self._back_buf)
|
||||
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def lookahead_buffer(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the current lookahead buffer.
|
||||
"""
|
||||
return list(self._ahead_buf)
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
"""
|
||||
return steps <= len(self._back_buf) + self._cursor
|
||||
|
||||
def can_peek_ahead(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can peek ahead `steps` items.
|
||||
This may involve trying to fill the ahead buffer.
|
||||
"""
|
||||
if self._lookahead > 0 and steps > self._lookahead:
|
||||
return False
|
||||
|
||||
# Try to fill ahead buffer to check if we can peek that far
|
||||
try:
|
||||
while len(self._ahead_buf) < steps:
|
||||
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
||||
return False
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
return True
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def reset_cursor(self) -> None:
|
||||
"""
|
||||
Reset cursor to the most recent position (equivalent to calling next()
|
||||
until you're back to the latest item).
|
||||
"""
|
||||
self._cursor = 0
|
||||
|
||||
def clear_ahead_buffer(self) -> None:
|
||||
"""
|
||||
Clear the ahead buffer, discarding any pre-fetched items.
|
||||
"""
|
||||
self._ahead_buf.clear()
|
||||
|
||||
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
|
||||
"""
|
||||
Switch the source of the backtrackable to a new iterable, keeping the history.
|
||||
|
||||
This is useful when iterating over a sequence of datasets. The history from the
|
||||
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
|
||||
to the present.
|
||||
"""
|
||||
self._source = iter(new_source)
|
||||
self.clear_ahead_buffer()
|
||||
self.reset_cursor()
|
||||
|
||||
|
||||
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
|
||||
"""
|
||||
Safe shards the dataset.
|
||||
"""
|
||||
shard_idx = min(dataset.num_shards, index + 1) - 1
|
||||
|
||||
return dataset.shard(num_shards, index=shard_idx)
|
||||
|
||||
@@ -70,7 +70,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concat_video_files, get_video_duration_in_s
|
||||
|
||||
V21 = "v2.1"
|
||||
|
||||
@@ -204,8 +204,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
@@ -288,11 +287,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
# Check if adding this episode would exceed the limit
|
||||
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||
concatenate_video_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
)
|
||||
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
|
||||
|
||||
# Update episodes metadata for the file we just saved
|
||||
for i, _ in enumerate(paths_to_cat):
|
||||
@@ -324,11 +319,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concatenate_video_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
)
|
||||
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
|
||||
|
||||
# Update episodes metadata for the final file
|
||||
for i, _ in enumerate(paths_to_cat):
|
||||
|
||||
@@ -17,21 +17,22 @@ import glob
|
||||
import importlib
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import fsspec
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_PATH
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
@@ -171,68 +172,15 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def get_decoder(self, video_path: str):
|
||||
"""Get a cached decoder or create a new one."""
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
video_path = str(video_path)
|
||||
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
return self._cache[video_path][0]
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache and close file handles."""
|
||||
with self._lock:
|
||||
for _, file_handle in self._cache.values():
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the number of cached decoders."""
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
|
||||
class FrameTimestampError(ValueError):
|
||||
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
_default_decoder_cache = VideoDecoderCache()
|
||||
|
||||
|
||||
def decode_video_frames_torchcodec(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
decoder_cache: VideoDecoderCache | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
timestamps: List of timestamps to extract frames.
|
||||
tolerance_s: Allowed deviation in seconds for frame retrieval.
|
||||
log_loaded_timestamps: Whether to log loaded timestamps.
|
||||
decoder_cache: Optional decoder cache instance. Uses default if None.
|
||||
|
||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
@@ -241,24 +189,27 @@ def decode_video_frames_torchcodec(
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
loaded_ts = []
|
||||
# initialize video decoder
|
||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||
loaded_frames = []
|
||||
|
||||
loaded_ts = []
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
@@ -289,14 +240,10 @@ def decode_video_frames_torchcodec(
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
|
||||
if not len(timestamps) == len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
||||
)
|
||||
# convert to float32 in [0,1] range (channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -320,10 +267,6 @@ def encode_video_frames(
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
|
||||
if video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
return
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
@@ -392,87 +335,60 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
):
|
||||
def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
Concatenate multiple video files into a single video file using ffmpeg.
|
||||
|
||||
This function takes a list of video input file paths and concatenates them into a single
|
||||
This function takes a list of video file paths and concatenates them into a single
|
||||
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
concatenation without re-encoding.
|
||||
|
||||
Args:
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
paths_to_cat: List of video file paths to concatenate, in order.
|
||||
root: Root directory where temporary files and output will be created.
|
||||
video_key: Video key identifier (e.g., camera name) used in output path.
|
||||
chunk_idx: Chunk index for organizing output files.
|
||||
file_idx: File index within the chunk.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
|
||||
codec, resolution, and frame rate for proper concatenation.
|
||||
- Output path follows the DEFAULT_VIDEO_PATH pattern with video_key, chunk_idx,
|
||||
and file_idx parameters.
|
||||
- This function uses subprocess to call ffmpeg directly because PyAV doesn't have
|
||||
built-in support for video concatenation. The concat demuxer in ffmpeg handles
|
||||
all the complex timestamp adjustments automatically.
|
||||
"""
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
tmp_dir = Path(tempfile.mkdtemp(dir=root))
|
||||
path_concat_video_files = tmp_dir / "concat_video_files.txt"
|
||||
with open(path_concat_video_files, "w") as f:
|
||||
for ep_path in paths_to_cat:
|
||||
f.write(f"file '{str(ep_path)}'\n")
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
return
|
||||
path_tmp_output = tmp_dir / "tmp_output.mp4"
|
||||
command = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f",
|
||||
"concat",
|
||||
"-safe",
|
||||
"0",
|
||||
"-i",
|
||||
str(path_concat_video_files),
|
||||
"-c",
|
||||
"copy",
|
||||
str(path_tmp_output),
|
||||
]
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
for input_path in input_video_paths:
|
||||
tmp_concatenate_file.write(f"file '{str(input_path)}'\n")
|
||||
tmp_concatenate_file.flush()
|
||||
tmp_concatenate_path = tmp_concatenate_file.name
|
||||
|
||||
# Create input and output containers
|
||||
input_container = av.open(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
|
||||
# Replicate input streams in output container
|
||||
stream_map = {}
|
||||
for input_stream in input_container.streams:
|
||||
if input_stream.type in ("video", "audio", "subtitle"): # only copy compatible streams
|
||||
stream_map[input_stream.index] = output_container.add_stream_from_template(
|
||||
template=input_stream, opaque=True
|
||||
)
|
||||
stream_map[
|
||||
input_stream.index
|
||||
].time_base = (
|
||||
input_stream.time_base
|
||||
) # set the time base to the input stream time base (missing in the codec context)
|
||||
|
||||
# Demux + remux packets (no re-encode)
|
||||
for packet in input_container.demux():
|
||||
# Skip packets from un-mapped streams
|
||||
if packet.stream.index not in stream_map:
|
||||
continue
|
||||
|
||||
# Skip demux flushing packets
|
||||
if packet.dts is None:
|
||||
continue
|
||||
|
||||
output_stream = stream_map[packet.stream.index]
|
||||
packet.stream = output_stream
|
||||
output_container.mux(packet)
|
||||
|
||||
input_container.close()
|
||||
output_container.close()
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
output_path = root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(path_tmp_output), str(output_path))
|
||||
shutil.rmtree(str(tmp_dir))
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -618,66 +534,3 @@ def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
# Fallback to container duration if stream duration is not available
|
||||
duration = float(container.duration / av.time_base)
|
||||
return duration
|
||||
|
||||
|
||||
class VideoEncodingManager:
|
||||
"""
|
||||
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
|
||||
|
||||
This manager handles:
|
||||
- Batch encoding for any remaining episodes when recording interrupted
|
||||
- Cleaning up temporary image files from interrupted episodes
|
||||
- Removing empty image directories
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset instance
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if self.dataset.episodes_since_last_encoding > 0:
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
logging.info("Recording stopped. Encoding remaining episodes...")
|
||||
|
||||
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
|
||||
end_ep = self.dataset.num_episodes
|
||||
logging.info(
|
||||
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
|
||||
# Clean up episode images if recording was interrupted
|
||||
if exc_type is not None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
for key in self.dataset.meta.video_keys:
|
||||
img_dir = self.dataset._get_image_file_path(
|
||||
episode_index=interrupted_episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
if img_dir.exists():
|
||||
logging.debug(
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
# Check for any remaining PNG files
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
# Only remove the images directory if no PNG files remain
|
||||
if img_dir.exists():
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -73,7 +73,6 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
@@ -209,14 +208,7 @@ def record_loop(
|
||||
(
|
||||
t
|
||||
for t in teleop
|
||||
if isinstance(
|
||||
t,
|
||||
(
|
||||
so100_leader.SO100Leader,
|
||||
so101_leader.SO101Leader,
|
||||
koch_leader.KochLeader,
|
||||
),
|
||||
)
|
||||
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -309,7 +301,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
@@ -330,7 +321,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
@@ -342,47 +332,46 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
|
||||
@@ -55,7 +55,6 @@ from lerobot.robots import ( # noqa: F401
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
reachy2,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
@@ -93,15 +92,11 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
|
||||
actions = episode_frames.select_columns("action")
|
||||
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
|
||||
@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Optional
|
||||
left_arm_disable_torque_on_disconnect: bool = True
|
||||
left_arm_max_relative_target: float | dict[str, float] | None = None
|
||||
left_arm_max_relative_target: int | None = None
|
||||
left_arm_use_degrees: bool = False
|
||||
right_arm_disable_torque_on_disconnect: bool = True
|
||||
right_arm_max_relative_target: float | dict[str, float] | None = None
|
||||
right_arm_max_relative_target: int | None = None
|
||||
right_arm_use_degrees: bool = False
|
||||
|
||||
# cameras (shared between both arms)
|
||||
|
||||
@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -110,7 +110,6 @@ class KochFollower(Robot):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
@@ -121,6 +120,7 @@ class KochFollower(Robot):
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
|
||||
@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
from .robot_reachy2 import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Robot,
|
||||
)
|
||||
@@ -1,107 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.cameras.configs import ColorMode
|
||||
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("reachy2")
|
||||
@dataclass
|
||||
class Reachy2RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors.
|
||||
max_relative_target: float | None = None
|
||||
|
||||
# IP address of the Reachy 2 robot
|
||||
ip_address: str | None = "localhost"
|
||||
|
||||
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
|
||||
disable_torque_on_disconnect: bool = False
|
||||
|
||||
# Tag for external commands control
|
||||
# Set to True if you use an external commands system to control the robot,
|
||||
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
|
||||
# If True, robot.send_action() will not send commands to the robot.
|
||||
use_external_commands: bool = False
|
||||
|
||||
# Robot parts
|
||||
# Set to False to not add the corresponding joints part to the robot list of joints.
|
||||
# By default, all parts are set to True.
|
||||
with_mobile_base: bool = True
|
||||
with_l_arm: bool = True
|
||||
with_r_arm: bool = True
|
||||
with_neck: bool = True
|
||||
with_antennas: bool = True
|
||||
|
||||
# Robot cameras
|
||||
# Set to True if you want to use the corresponding cameras in the observations.
|
||||
# By default, only the teleop cameras are used.
|
||||
with_left_teleop_camera: bool = True
|
||||
with_right_teleop_camera: bool = True
|
||||
with_torso_camera: bool = False
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add cameras with same ip_address as the robot
|
||||
if self.with_left_teleop_camera:
|
||||
self.cameras["teleop_left"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
if self.with_right_teleop_camera:
|
||||
self.cameras["teleop_right"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="right",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
if self.with_torso_camera:
|
||||
self.cameras["torso_rgb"] = Reachy2CameraConfig(
|
||||
name="depth",
|
||||
image_type="rgb",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if not (
|
||||
self.with_mobile_base
|
||||
or self.with_l_arm
|
||||
or self.with_r_arm
|
||||
or self.with_neck
|
||||
or self.with_antennas
|
||||
):
|
||||
raise ValueError(
|
||||
"No Reachy2Robot part used.\n"
|
||||
"At least one part of the robot must be set to True "
|
||||
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
|
||||
)
|
||||
@@ -1,230 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_NECK_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
}
|
||||
|
||||
REACHY2_ANTENNAS_JOINTS = {
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
}
|
||||
|
||||
REACHY2_R_ARM_JOINTS = {
|
||||
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
|
||||
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
|
||||
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
|
||||
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
|
||||
"r_wrist_roll.pos": "r_arm.wrist.roll",
|
||||
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
|
||||
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
|
||||
"r_gripper.pos": "r_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_L_ARM_JOINTS = {
|
||||
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
|
||||
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
|
||||
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
|
||||
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
|
||||
"l_wrist_roll.pos": "l_arm.wrist.roll",
|
||||
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
|
||||
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
class Reachy2Robot(Robot):
|
||||
"""
|
||||
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
|
||||
"""
|
||||
|
||||
config_class = Reachy2RobotConfig
|
||||
name = "reachy2"
|
||||
|
||||
def __init__(self, config: Reachy2RobotConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.robot_type = self.config.type
|
||||
self.use_external_commands = self.config.use_external_commands
|
||||
|
||||
self.reachy: None | ReachySDK = None
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.logs: dict[str, float] = {}
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict[str, Any]:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self.motors_features
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
|
||||
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
|
||||
|
||||
@property
|
||||
def motors_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
self.joints_dict.keys(),
|
||||
float,
|
||||
),
|
||||
**dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
),
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self, calibrate: bool = False) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
|
||||
def configure(self) -> None:
|
||||
if self.reachy is not None:
|
||||
self.reachy.turn_on()
|
||||
self.reachy.reset_default_limits()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
joints = {}
|
||||
if self.config.with_neck:
|
||||
joints.update(REACHY2_NECK_JOINTS)
|
||||
if self.config.with_l_arm:
|
||||
joints.update(REACHY2_L_ARM_JOINTS)
|
||||
if self.config.with_r_arm:
|
||||
joints.update(REACHY2_R_ARM_JOINTS)
|
||||
if self.config.with_antennas:
|
||||
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return joints
|
||||
|
||||
def _get_state(self) -> dict[str, float]:
|
||||
if self.reachy is not None:
|
||||
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
|
||||
if not self.config.with_mobile_base:
|
||||
return pos_dict
|
||||
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
return {**pos_dict, **vel_dict}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
before_read_t = time.perf_counter()
|
||||
obs_dict.update(self._get_state())
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
|
||||
vel = {}
|
||||
goal_pos = {}
|
||||
for key, val in action.items():
|
||||
if key not in self.joints_dict:
|
||||
if key not in REACHY2_VEL:
|
||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||
else:
|
||||
vel[REACHY2_VEL[key]] = float(val)
|
||||
else:
|
||||
if not self.use_external_commands and self.config.max_relative_target is not None:
|
||||
goal_pos[key] = float(val)
|
||||
goal_present_pos = {
|
||||
key: (
|
||||
goal_pos[key],
|
||||
self.reachy.joints[self.joints_dict[key]].present_position,
|
||||
)
|
||||
}
|
||||
safe_goal_pos = ensure_safe_goal_position(
|
||||
goal_present_pos, float(self.config.max_relative_target)
|
||||
)
|
||||
val = safe_goal_pos[key]
|
||||
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
|
||||
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
||||
|
||||
# We don't send the goal positions if we control Reachy 2 externally
|
||||
if not self.use_external_commands:
|
||||
self.reachy.send_goal_positions()
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
return action
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.reachy is not None:
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
if self.config.disable_torque_on_disconnect:
|
||||
self.reachy.turn_off_smoothly()
|
||||
self.reachy.disconnect()
|
||||
@@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -161,11 +161,6 @@ class SO100Follower(Robot):
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 32)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
|
||||
@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -157,13 +157,6 @@ class SO101Follower(Robot):
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 32)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write(
|
||||
"Max_Torque_Limit", motor, 500
|
||||
) # 50% of the max torque limit to avoid burnout
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
|
||||
@@ -24,6 +24,11 @@ from ..config import RobotConfig
|
||||
@RobotConfig.register_subclass("stretch3")
|
||||
@dataclass
|
||||
class Stretch3RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
|
||||
@@ -61,10 +61,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .bi_so100_follower import BiSO100Follower
|
||||
|
||||
return BiSO100Follower(config)
|
||||
elif config.type == "reachy2":
|
||||
from .reachy2 import Reachy2Robot
|
||||
|
||||
return Reachy2Robot(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
@@ -74,7 +70,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
) -> dict[str, float]:
|
||||
"""Caps relative action target magnitude for safety."""
|
||||
|
||||
|
||||
@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: float | dict[str, float] = 5.0
|
||||
max_relative_target: int | None = 5
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -1,234 +0,0 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
|
||||
def profile_throughput_indexed(
|
||||
dataset: LeRobotDataset, num_samples: int, warmup_iters: int = 3
|
||||
) -> np.ndarray:
|
||||
"""Measure per-item access time on an indexable LeRobotDataset.
|
||||
|
||||
Accesses dataset[i % len(dataset)] for ``num_samples`` iterations, with an initial warmup.
|
||||
"""
|
||||
next_times = np.zeros(num_samples)
|
||||
total = len(dataset)
|
||||
|
||||
# warmup
|
||||
for k in range(warmup_iters):
|
||||
_ = dataset[k % total]
|
||||
|
||||
for j in tqdm(range(num_samples), desc="Profiling dataset throughput", unit="item"):
|
||||
start_time = time.perf_counter()
|
||||
_ = dataset[j % total]
|
||||
end_time = time.perf_counter()
|
||||
next_times[j] = end_time - start_time
|
||||
|
||||
return next_times
|
||||
|
||||
|
||||
def profile_throughput(
|
||||
dataset: StreamingLeRobotDataset, num_samples: int, warmup_iters: int = 3
|
||||
) -> np.ndarray:
|
||||
"""Measure ``.next()`` call latency on a streaming dataset.
|
||||
|
||||
Performs a configurable warmup. This does not numerically "normalize" times; it simply
|
||||
avoids including initialization overhead in the timing window.
|
||||
"""
|
||||
next_times = np.zeros(num_samples)
|
||||
iter_dataset = iter(dataset)
|
||||
|
||||
# warmup
|
||||
for _ in range(warmup_iters):
|
||||
_ = next(iter_dataset)
|
||||
|
||||
for j in tqdm(range(num_samples), desc="Profiling throughput", unit="call"):
|
||||
start_time = time.perf_counter()
|
||||
_sample = next(iter_dataset)
|
||||
end_time = time.perf_counter()
|
||||
next_times[j] = end_time - start_time
|
||||
|
||||
return next_times
|
||||
|
||||
|
||||
def profile_init(dataset_factory: Callable[[], StreamingLeRobotDataset], num_runs: int) -> np.ndarray:
|
||||
"""Measure time-to-first-sample by re-instantiating the dataset ``num_runs`` times.
|
||||
|
||||
Using a factory avoids unsafe ``deepcopy`` of objects that may own threads or file handles.
|
||||
"""
|
||||
init_times = np.zeros(num_runs)
|
||||
for i in tqdm(range(num_runs), desc="Profiling init", unit="run"):
|
||||
fresh_dataset = dataset_factory()
|
||||
iter_dataset = iter(fresh_dataset)
|
||||
start_time = time.perf_counter()
|
||||
_ = next(iter_dataset)
|
||||
end_time = time.perf_counter()
|
||||
init_times[i] = end_time - start_time
|
||||
|
||||
return init_times
|
||||
|
||||
|
||||
def profile_randomness(dataset: StreamingLeRobotDataset, num_samples: int) -> float:
|
||||
"""Measure how random the sample order is via correlation.
|
||||
|
||||
Returns a Pearson correlation between retrieved frame index and iteration index.
|
||||
- ~0: random order
|
||||
- ~+1: strictly increasing (in-order)
|
||||
- ~-1: strictly decreasing (reverse order)
|
||||
"""
|
||||
frame_indices = np.zeros(num_samples, dtype=float)
|
||||
iter_indices = np.arange(num_samples, dtype=float)
|
||||
|
||||
iter_dataset = iter(dataset)
|
||||
|
||||
for i in tqdm(range(num_samples), desc="Profiling randomness", unit="sample"):
|
||||
sample = next(iter_dataset)
|
||||
if "index" in sample:
|
||||
frame_idx_value = sample["index"]
|
||||
elif "frame_index" in sample:
|
||||
frame_idx_value = sample["frame_index"]
|
||||
else:
|
||||
raise KeyError("Sample is missing 'index' (or 'frame_index') required to compute randomness.")
|
||||
frame_indices[i] = float(frame_idx_value)
|
||||
|
||||
# Guard against degenerate cases
|
||||
if num_samples < 2 or np.std(frame_indices) == 0 or np.std(iter_indices) == 0:
|
||||
return np.nan, None
|
||||
|
||||
correlation = float(np.corrcoef(frame_indices, iter_indices)[0, 1])
|
||||
return correlation
|
||||
|
||||
|
||||
def profile_streaming_dataset(
|
||||
repo_id: str,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
num_samples: int = 100,
|
||||
warmup_iters: int = 10,
|
||||
buffer_size: int = 1000,
|
||||
) -> tuple[np.ndarray, np.ndarray, float]:
|
||||
"""Run init, throughput, and randomness profiles on a StreamingLeRobotDataset."""
|
||||
|
||||
def dataset_factory() -> StreamingLeRobotDataset:
|
||||
return StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps, buffer_size=buffer_size)
|
||||
|
||||
# Measure init by repeated instantiation
|
||||
init_times = profile_init(dataset_factory, num_runs=warmup_iters)
|
||||
|
||||
# Throughput and randomness on a single fresh dataset instance
|
||||
dataset = dataset_factory()
|
||||
next_times = profile_throughput(dataset, num_samples=num_samples, warmup_iters=warmup_iters)
|
||||
correlation = profile_randomness(dataset, num_samples=num_samples)
|
||||
|
||||
return init_times, next_times, correlation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Profile StreamingLeRobotDataset performance metrics.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/svla_so101_pickplace",
|
||||
help="Dataset repo_id to profile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of samples to measure for throughput/randomness.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations for init and throughput warmup.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Buffer size for the streaming dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with-delta-timestamps",
|
||||
action="store_true",
|
||||
help="Profile with delta timestamps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compare-with-local",
|
||||
action="store_true",
|
||||
help="Also profile local LeRobotDataset throughput for comparison.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
default=os.path.join("outputs", "benchmarks"),
|
||||
help="Directory to write CSVs/PNGs to.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
delta_timestamps = (
|
||||
None
|
||||
if not args.with_delta_timestamps
|
||||
else {
|
||||
"observation.state": [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0],
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
init_times, next_times, correlation = profile_streaming_dataset(
|
||||
repo_id=args.repo_id,
|
||||
delta_timestamps=delta_timestamps,
|
||||
num_samples=args.num_samples,
|
||||
warmup_iters=args.warmup_iters,
|
||||
buffer_size=args.buffer_size,
|
||||
)
|
||||
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
|
||||
repo_id_str = args.repo_id.replace("/", "-")
|
||||
date_str = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
name_suffix = f"{repo_id_str}_buf{args.buffer_size}_{date_str}"
|
||||
|
||||
# Visualization disabled by default; figures are not created or saved.
|
||||
|
||||
init_df = pd.DataFrame({"init_times": init_times})
|
||||
next_df = pd.DataFrame({"next_times": next_times})
|
||||
correlation_df = pd.DataFrame({"correlation": [correlation]})
|
||||
|
||||
init_df.to_csv(os.path.join(args.outdir, f"init_times_{name_suffix}.csv"), index=False)
|
||||
next_df.to_csv(os.path.join(args.outdir, f"next_times_{name_suffix}.csv"), index=False)
|
||||
correlation_df.to_csv(os.path.join(args.outdir, f"correlation_{name_suffix}.csv"), index=False)
|
||||
|
||||
if args.compare_with_local:
|
||||
# Profile local non-streaming dataset throughput for comparison
|
||||
local_ds = LeRobotDataset(args.repo_id, delta_timestamps=delta_timestamps)
|
||||
local_next_times = profile_throughput_indexed(
|
||||
local_ds, num_samples=args.num_samples, warmup_iters=args.warmup_iters
|
||||
)
|
||||
local_df = pd.DataFrame({"next_times": local_next_times})
|
||||
local_df.to_csv(
|
||||
os.path.join(args.outdir, f"next_times_local_{repo_id_str}_{date_str}.csv"),
|
||||
index=False,
|
||||
)
|
||||
@@ -179,11 +179,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle and not cfg.dataset.streaming,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
@@ -209,9 +208,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
if batch[key].dtype != torch.bool:
|
||||
batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key]
|
||||
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
|
||||
@@ -88,7 +88,6 @@ class KochLeader(Teleoperator):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
@@ -99,6 +98,7 @@ class KochLeader(Teleoperator):
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
from .reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Teleoperator,
|
||||
)
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
|
||||
@dataclass
|
||||
class Reachy2TeleoperatorConfig(TeleoperatorConfig):
|
||||
# IP address of the Reachy 2 robot used as teleoperator
|
||||
ip_address: str | None = "localhost"
|
||||
|
||||
# Whether to use the present position of the joints as actions
|
||||
# if False, the goal position of the joints will be used
|
||||
use_present_position: bool = False
|
||||
|
||||
# Which parts of the robot to use
|
||||
with_mobile_base: bool = True
|
||||
with_l_arm: bool = True
|
||||
with_r_arm: bool = True
|
||||
with_neck: bool = True
|
||||
with_antennas: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not (
|
||||
self.with_mobile_base
|
||||
or self.with_l_arm
|
||||
or self.with_r_arm
|
||||
or self.with_neck
|
||||
or self.with_antennas
|
||||
):
|
||||
raise ValueError(
|
||||
"No Reachy2Teleoperator part used.\n"
|
||||
"At least one part of the robot must be set to True "
|
||||
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
|
||||
)
|
||||
@@ -1,164 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_NECK_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
}
|
||||
|
||||
REACHY2_ANTENNAS_JOINTS = {
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
}
|
||||
|
||||
REACHY2_R_ARM_JOINTS = {
|
||||
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
|
||||
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
|
||||
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
|
||||
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
|
||||
"r_wrist_roll.pos": "r_arm.wrist.roll",
|
||||
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
|
||||
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
|
||||
"r_gripper.pos": "r_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_L_ARM_JOINTS = {
|
||||
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
|
||||
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
|
||||
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
|
||||
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
|
||||
"l_wrist_roll.pos": "l_arm.wrist.roll",
|
||||
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
|
||||
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
class Reachy2Teleoperator(Teleoperator):
|
||||
"""
|
||||
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
|
||||
"""
|
||||
|
||||
config_class = Reachy2TeleoperatorConfig
|
||||
name = "reachy2_specific"
|
||||
|
||||
def __init__(self, config: Reachy2TeleoperatorConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.reachy: None | ReachySDK = None
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
joints = {}
|
||||
if self.config.with_neck:
|
||||
joints.update(REACHY2_NECK_JOINTS)
|
||||
if self.config.with_l_arm:
|
||||
joints.update(REACHY2_L_ARM_JOINTS)
|
||||
if self.config.with_r_arm:
|
||||
joints.update(REACHY2_R_ARM_JOINTS)
|
||||
if self.config.with_antennas:
|
||||
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return joints
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
self.joints_dict.keys(),
|
||||
float,
|
||||
),
|
||||
**dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
),
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
|
||||
if self.reachy and self.is_connected:
|
||||
if self.config.use_present_position:
|
||||
joint_action = {
|
||||
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
|
||||
}
|
||||
else:
|
||||
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
|
||||
|
||||
if not self.config.with_mobile_base:
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return joint_action
|
||||
|
||||
if self.config.use_present_position:
|
||||
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
else:
|
||||
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return {**joint_action, **vel_action}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.reachy and self.is_connected:
|
||||
self.reachy.disconnect()
|
||||
@@ -65,9 +65,5 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .bi_so100_leader import BiSO100Leader
|
||||
|
||||
return BiSO100Leader(config)
|
||||
elif config.type == "reachy2_teleoperator":
|
||||
from .reachy2_teleoperator import Reachy2Teleoperator
|
||||
|
||||
return Reachy2Teleoperator(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
PARAMS = [
|
||||
("teleop", "left"),
|
||||
("teleop", "right"),
|
||||
("depth", "rgb"),
|
||||
# ("depth", "depth"), # Depth camera is not available yet
|
||||
]
|
||||
|
||||
|
||||
def _make_cam_manager_mock():
|
||||
c = MagicMock(name="CameraManagerMock")
|
||||
|
||||
teleop = MagicMock(name="TeleopCam")
|
||||
teleop.width = 640
|
||||
teleop.height = 480
|
||||
teleop.get_frame = MagicMock(
|
||||
side_effect=lambda *_, **__: (
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
time.time(),
|
||||
)
|
||||
)
|
||||
|
||||
depth = MagicMock(name="DepthCam")
|
||||
depth.width = 640
|
||||
depth.height = 480
|
||||
depth.get_frame = MagicMock(
|
||||
side_effect=lambda *_, **__: (
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
time.time(),
|
||||
)
|
||||
)
|
||||
|
||||
c.is_connected.return_value = True
|
||||
c.teleop = teleop
|
||||
c.depth = depth
|
||||
|
||||
def _connect():
|
||||
c.teleop = teleop
|
||||
c.depth = depth
|
||||
c.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
c.teleop = None
|
||||
c.depth = None
|
||||
c.is_connected.return_value = False
|
||||
|
||||
c.connect = MagicMock(side_effect=_connect)
|
||||
c.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
# Mock methods
|
||||
c.initialize_cameras = MagicMock()
|
||||
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=PARAMS,
|
||||
# ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"],
|
||||
ids=["teleop-left", "teleop-right", "torso-rgb"],
|
||||
)
|
||||
def camera(request):
|
||||
name, image_type = request.param
|
||||
with (
|
||||
patch(
|
||||
"lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager",
|
||||
side_effect=lambda *a, **k: _make_cam_manager_mock(),
|
||||
),
|
||||
):
|
||||
config = Reachy2CameraConfig(name=name, image_type=image_type)
|
||||
cam = Reachy2Camera(config)
|
||||
yield cam
|
||||
if cam.is_connected:
|
||||
cam.disconnect()
|
||||
|
||||
|
||||
def test_connect(camera):
|
||||
camera.connect()
|
||||
assert camera.is_connected
|
||||
camera.cam_manager.initialize_cameras.assert_called_once()
|
||||
|
||||
|
||||
def test_read(camera):
|
||||
camera.connect()
|
||||
|
||||
img = camera.read()
|
||||
if camera.config.name == "teleop":
|
||||
camera.cam_manager.teleop.get_frame.assert_called_once()
|
||||
elif camera.config.name == "depth":
|
||||
camera.cam_manager.depth.get_frame.assert_called_once()
|
||||
assert isinstance(img, np.ndarray)
|
||||
assert img.shape == (480, 640, 3)
|
||||
|
||||
|
||||
def test_disconnect(camera):
|
||||
camera.connect()
|
||||
|
||||
camera.disconnect()
|
||||
assert not camera.is_connected
|
||||
|
||||
|
||||
def test_async_read(camera):
|
||||
camera.connect()
|
||||
try:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_timeout(camera):
|
||||
camera.connect()
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_read_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read()
|
||||
|
||||
|
||||
def test_disconnect_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
def test_wrong_camera_name():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
|
||||
|
||||
|
||||
def test_wrong_image_type():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="teleop", image_type="rgb")
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="depth", image_type="left")
|
||||
|
||||
|
||||
def test_wrong_color_mode():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color")
|
||||
@@ -28,7 +28,6 @@ pytest_plugins = [
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
"tests.fixtures.optimizers",
|
||||
"tests.plugins.reachy2_sdk",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -789,269 +789,3 @@ def test_update_chunk_settings_video_dataset(tmp_path):
|
||||
dataset.meta.update_chunk_settings(video_files_size_in_mb=new_video_size)
|
||||
assert dataset.meta.get_chunk_settings()["video_files_size_in_mb"] == new_video_size
|
||||
assert dataset.meta.video_files_size_in_mb == new_video_size
|
||||
|
||||
|
||||
def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that all frames have correct episode indices across multiple episodes."""
|
||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
# Create 3 episodes with different lengths
|
||||
num_episodes = 3
|
||||
frames_per_episode = [10, 15, 8]
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
for _ in range(frames_per_episode[episode_idx]):
|
||||
dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
|
||||
dataset.save_episode()
|
||||
|
||||
# Load the dataset and check episode indices
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check specific frames across episode boundaries
|
||||
cumulative = 0
|
||||
for ep_idx, ep_length in enumerate(frames_per_episode):
|
||||
# Check start, middle, and end of each episode
|
||||
start_frame = cumulative
|
||||
middle_frame = cumulative + ep_length // 2
|
||||
end_frame = cumulative + ep_length - 1
|
||||
|
||||
for frame_idx in [start_frame, middle_frame, end_frame]:
|
||||
frame_data = loaded_dataset[frame_idx]
|
||||
actual_ep_idx = frame_data["episode_index"].item()
|
||||
assert actual_ep_idx == ep_idx, (
|
||||
f"Frame {frame_idx} has episode_index {actual_ep_idx}, should be {ep_idx}"
|
||||
)
|
||||
|
||||
cumulative += ep_length
|
||||
|
||||
# Check episode index distribution
|
||||
all_episode_indices = [loaded_dataset[i]["episode_index"].item() for i in range(len(loaded_dataset))]
|
||||
from collections import Counter
|
||||
|
||||
distribution = Counter(all_episode_indices)
|
||||
expected_dist = {i: frames_per_episode[i] for i in range(num_episodes)}
|
||||
|
||||
assert dict(distribution) == expected_dist, (
|
||||
f"Episode distribution {dict(distribution)} != expected {expected_dist}"
|
||||
)
|
||||
|
||||
|
||||
def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test episode metadata consistency across multiple episodes."""
|
||||
features = {
|
||||
"state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
num_episodes = 4
|
||||
frames_per_episode = [20, 35, 10, 25]
|
||||
tasks = ["pick", "place", "pick", "place"]
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
for _ in range(frames_per_episode[episode_idx]):
|
||||
dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]})
|
||||
dataset.save_episode()
|
||||
|
||||
# Load and validate episode metadata
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
assert loaded_dataset.meta.total_episodes == num_episodes
|
||||
assert loaded_dataset.meta.total_frames == sum(frames_per_episode)
|
||||
|
||||
cumulative_frames = 0
|
||||
for episode_idx in range(num_episodes):
|
||||
episode_metadata = loaded_dataset.meta.episodes[episode_idx]
|
||||
|
||||
# Check basic episode properties
|
||||
assert episode_metadata["episode_index"] == episode_idx
|
||||
assert episode_metadata["length"] == frames_per_episode[episode_idx]
|
||||
assert episode_metadata["tasks"] == [tasks[episode_idx]]
|
||||
|
||||
# Check dataset indices
|
||||
expected_from = cumulative_frames
|
||||
expected_to = cumulative_frames + frames_per_episode[episode_idx]
|
||||
|
||||
assert episode_metadata["dataset_from_index"] == expected_from
|
||||
assert episode_metadata["dataset_to_index"] == expected_to
|
||||
|
||||
cumulative_frames += frames_per_episode[episode_idx]
|
||||
|
||||
|
||||
def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that episodes have no gaps or overlaps in their data indices."""
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
num_episodes = 5
|
||||
frames_per_episode = [12, 8, 20, 15, 5]
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
for _ in range(frames_per_episode[episode_idx]):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
|
||||
dataset.save_episode()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check data consistency - no gaps or overlaps
|
||||
cumulative_check = 0
|
||||
for episode_idx in range(num_episodes):
|
||||
episode_metadata = loaded_dataset.meta.episodes[episode_idx]
|
||||
from_idx = episode_metadata["dataset_from_index"]
|
||||
to_idx = episode_metadata["dataset_to_index"]
|
||||
|
||||
# Check that episode starts exactly where previous ended
|
||||
assert from_idx == cumulative_check, (
|
||||
f"Episode {episode_idx} starts at {from_idx}, expected {cumulative_check}"
|
||||
)
|
||||
|
||||
# Check that episode length matches expected
|
||||
actual_length = to_idx - from_idx
|
||||
expected_length = frames_per_episode[episode_idx]
|
||||
assert actual_length == expected_length, (
|
||||
f"Episode {episode_idx} length {actual_length} != expected {expected_length}"
|
||||
)
|
||||
|
||||
cumulative_check = to_idx
|
||||
|
||||
# Final check: last episode should end at total frames
|
||||
expected_total_frames = sum(frames_per_episode)
|
||||
assert cumulative_check == expected_total_frames, (
|
||||
f"Final frame count {cumulative_check} != expected {expected_total_frames}"
|
||||
)
|
||||
|
||||
|
||||
def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that statistics are properly computed and stored for all features."""
|
||||
features = {
|
||||
"state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
|
||||
"action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
# Create controlled data to verify statistics
|
||||
num_episodes = 2
|
||||
frames_per_episode = [10, 10]
|
||||
|
||||
# Use deterministic data for predictable statistics
|
||||
torch.manual_seed(42)
|
||||
for episode_idx in range(num_episodes):
|
||||
for frame_idx in range(frames_per_episode[episode_idx]):
|
||||
state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32)
|
||||
action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32)
|
||||
dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"})
|
||||
dataset.save_episode()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check that statistics exist for all features
|
||||
assert loaded_dataset.meta.stats is not None, "No statistics found"
|
||||
|
||||
for feature_name in features.keys():
|
||||
assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
|
||||
|
||||
feature_stats = loaded_dataset.meta.stats[feature_name]
|
||||
expected_stats = ["min", "max", "mean", "std", "count"]
|
||||
|
||||
for stat_key in expected_stats:
|
||||
assert stat_key in feature_stats, f"Missing '{stat_key}' statistic for '{feature_name}'"
|
||||
|
||||
stat_value = feature_stats[stat_key]
|
||||
# Basic sanity checks
|
||||
if stat_key == "count":
|
||||
assert stat_value == sum(frames_per_episode), f"Wrong count for '{feature_name}'"
|
||||
elif stat_key in ["min", "max", "mean", "std"]:
|
||||
# Check that statistics are reasonable (not NaN, proper shapes)
|
||||
if hasattr(stat_value, "shape"):
|
||||
expected_shape = features[feature_name]["shape"]
|
||||
assert stat_value.shape == expected_shape or len(stat_value) == expected_shape[0], (
|
||||
f"Wrong shape for {stat_key} of '{feature_name}'"
|
||||
)
|
||||
# Check no NaN values
|
||||
if hasattr(stat_value, "__iter__"):
|
||||
assert not any(np.isnan(v) for v in stat_value), f"NaN in {stat_key} for '{feature_name}'"
|
||||
else:
|
||||
assert not np.isnan(stat_value), f"NaN in {stat_key} for '{feature_name}'"
|
||||
|
||||
|
||||
def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test frame indices and episode transitions at episode boundaries."""
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
num_episodes = 3
|
||||
frames_per_episode = [7, 12, 5]
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
for frame_idx in range(frames_per_episode[episode_idx]):
|
||||
dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
|
||||
dataset.save_episode()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Test episode boundaries
|
||||
cumulative = 0
|
||||
for ep_idx, ep_length in enumerate(frames_per_episode):
|
||||
if ep_idx > 0:
|
||||
# Check last frame of previous episode
|
||||
prev_frame = loaded_dataset[cumulative - 1]
|
||||
assert prev_frame["episode_index"].item() == ep_idx - 1
|
||||
|
||||
# Check first frame of current episode
|
||||
if cumulative < len(loaded_dataset):
|
||||
curr_frame = loaded_dataset[cumulative]
|
||||
assert curr_frame["episode_index"].item() == ep_idx
|
||||
|
||||
# Check frame_index within episode
|
||||
for i in range(ep_length):
|
||||
if cumulative + i < len(loaded_dataset):
|
||||
frame = loaded_dataset[cumulative + i]
|
||||
assert frame["frame_index"].item() == i, f"Frame {cumulative + i} has wrong frame_index"
|
||||
assert frame["episode_index"].item() == ep_idx, (
|
||||
f"Frame {cumulative + i} has wrong episode_index"
|
||||
)
|
||||
|
||||
cumulative += ep_length
|
||||
|
||||
|
||||
def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that tasks are properly indexed and retrievable."""
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
# Use multiple tasks, including repeated ones
|
||||
tasks = ["pick", "place", "pick", "navigate", "place"]
|
||||
unique_tasks = list(set(tasks)) # ["pick", "place", "navigate"]
|
||||
frames_per_episode = [5, 8, 3, 10, 6]
|
||||
|
||||
for episode_idx, task in enumerate(tasks):
|
||||
for _ in range(frames_per_episode[episode_idx]):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": task})
|
||||
dataset.save_episode()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check that all unique tasks are in the tasks metadata
|
||||
stored_tasks = set(loaded_dataset.meta.tasks.index)
|
||||
assert stored_tasks == set(unique_tasks), f"Stored tasks {stored_tasks} != expected {set(unique_tasks)}"
|
||||
|
||||
# Check that task indices are consistent
|
||||
cumulative = 0
|
||||
for episode_idx, expected_task in enumerate(tasks):
|
||||
episode_metadata = loaded_dataset.meta.episodes[episode_idx]
|
||||
assert episode_metadata["tasks"] == [expected_task]
|
||||
|
||||
# Check frames in this episode have correct task
|
||||
for i in range(frames_per_episode[episode_idx]):
|
||||
frame = loaded_dataset[cumulative + i]
|
||||
assert frame["task"] == expected_task, f"Frame {cumulative + i} has wrong task"
|
||||
|
||||
# Check task_index consistency
|
||||
expected_task_index = loaded_dataset.meta.get_task_index(expected_task)
|
||||
assert frame["task_index"].item() == expected_task_index
|
||||
|
||||
cumulative += frames_per_episode[episode_idx]
|
||||
|
||||
# Check total number of tasks
|
||||
assert loaded_dataset.meta.total_tasks == len(unique_tasks)
|
||||
|
||||
@@ -1,391 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import safe_shard
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
|
||||
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
|
||||
rng = np.random.default_rng(streaming_ds.seed)
|
||||
buffer_size = streaming_ds.buffer_size
|
||||
num_shards = streaming_ds.num_shards
|
||||
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
|
||||
|
||||
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
|
||||
|
||||
frames_buffer = []
|
||||
expected_indices = []
|
||||
|
||||
while shard_iterators: # While there are still available shards
|
||||
available_shard_keys = list(shard_iterators.keys())
|
||||
if not available_shard_keys:
|
||||
break
|
||||
|
||||
# Call _infinite_generator_over_elements with current available shards (key difference!)
|
||||
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
|
||||
|
||||
try:
|
||||
frame_index = next(shard_iterators[shard_key])
|
||||
|
||||
if len(frames_buffer) == buffer_size:
|
||||
i = next(buffer_indices_generator)
|
||||
expected_indices.append(frames_buffer[i])
|
||||
frames_buffer[i] = frame_index
|
||||
else:
|
||||
frames_buffer.append(frame_index)
|
||||
|
||||
except StopIteration:
|
||||
del shard_iterators[shard_key] # Remove exhausted shard
|
||||
|
||||
rng.shuffle(frames_buffer)
|
||||
expected_indices.extend(frames_buffer)
|
||||
|
||||
return expected_indices
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}"
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
)
|
||||
|
||||
streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))
|
||||
|
||||
key_checks = []
|
||||
for _ in range(ds_num_frames):
|
||||
streaming_frame = next(streaming_ds)
|
||||
frame_idx = streaming_frame["index"]
|
||||
target_frame = ds[frame_idx]
|
||||
|
||||
for key in streaming_frame:
|
||||
left = streaming_frame[key]
|
||||
right = target_frame[key]
|
||||
|
||||
if isinstance(left, str):
|
||||
check = left == right
|
||||
|
||||
elif isinstance(left, torch.Tensor):
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
elif isinstance(left, float):
|
||||
check = left == right.item() # right is a torch.Tensor
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (frame_idx: {frame_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shuffle",
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}"
|
||||
|
||||
lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shuffle",
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}-ciao"
|
||||
|
||||
lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
data_files_size_in_mb=data_file_size_mb,
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_deltas, action_deltas",
|
||||
[
|
||||
([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
|
||||
([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
([-2, -1, -0.5, 0], [0, 1, 2, 3]),
|
||||
([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
],
|
||||
)
|
||||
def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
|
||||
ds_num_frames = 500
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
|
||||
seed = 42
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}-ciao"
|
||||
camera_key = "phone"
|
||||
|
||||
delta_timestamps = {
|
||||
camera_key: state_deltas,
|
||||
"state": state_deltas,
|
||||
"action": action_deltas,
|
||||
}
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
streaming_ds = iter(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(ds_num_frames):
|
||||
streaming_frame = next(streaming_ds)
|
||||
frame_idx = streaming_frame["index"]
|
||||
target_frame = ds[frame_idx]
|
||||
|
||||
assert set(streaming_frame.keys()) == set(target_frame.keys()), (
|
||||
f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
|
||||
)
|
||||
|
||||
key_checks = []
|
||||
for key in streaming_frame:
|
||||
left = streaming_frame[key]
|
||||
right = target_frame[key]
|
||||
|
||||
if isinstance(left, str):
|
||||
check = left == right
|
||||
|
||||
elif isinstance(left, torch.Tensor):
|
||||
if (
|
||||
key not in ds.meta.camera_keys
|
||||
and "is_pad" not in key
|
||||
and f"{key}_is_pad" in streaming_frame
|
||||
):
|
||||
# comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
|
||||
left = left[~streaming_frame[f"{key}_is_pad"]]
|
||||
right = right[~target_frame[f"{key}_is_pad"]]
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_deltas, action_deltas",
|
||||
[
|
||||
([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
|
||||
([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
|
||||
([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
],
|
||||
)
|
||||
def test_frames_with_delta_consistency_with_shards(
|
||||
tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
|
||||
):
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
data_file_size_mb = 0.001
|
||||
chunks_size = 1
|
||||
|
||||
seed = 42
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}-ciao"
|
||||
camera_key = "phone"
|
||||
|
||||
delta_timestamps = {
|
||||
camera_key: state_deltas,
|
||||
"state": state_deltas,
|
||||
"action": action_deltas,
|
||||
}
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
delta_timestamps=delta_timestamps,
|
||||
data_files_size_in_mb=data_file_size_mb,
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
delta_timestamps=delta_timestamps,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
iter(streaming_ds)
|
||||
|
||||
num_shards = 4
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
streaming_ds = iter(streaming_ds)
|
||||
|
||||
for i in range(ds_num_frames):
|
||||
streaming_frame = next(streaming_ds)
|
||||
frame_idx = streaming_frame["index"]
|
||||
target_frame = ds[frame_idx]
|
||||
|
||||
assert set(streaming_frame.keys()) == set(target_frame.keys()), (
|
||||
f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
|
||||
)
|
||||
|
||||
key_checks = []
|
||||
for key in streaming_frame:
|
||||
left = streaming_frame[key]
|
||||
right = target_frame[key]
|
||||
|
||||
if isinstance(left, str):
|
||||
check = left == right
|
||||
|
||||
elif isinstance(left, torch.Tensor):
|
||||
if (
|
||||
key not in ds.meta.camera_keys
|
||||
and "is_pad" not in key
|
||||
and f"{key}_is_pad" in streaming_frame
|
||||
):
|
||||
# comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
|
||||
left = left[~streaming_frame[f"{key}_is_pad"]]
|
||||
right = right[~target_frame[f"{key}_is_pad"]]
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
elif isinstance(left, float):
|
||||
check = left == right.item() # right is a torch.Tensor
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def _install_reachy2_sdk_stub():
|
||||
sdk = types.ModuleType("reachy2_sdk")
|
||||
sdk.__path__ = []
|
||||
sdk.ReachySDK = MagicMock(name="ReachySDK")
|
||||
|
||||
media = types.ModuleType("reachy2_sdk.media")
|
||||
media.__path__ = []
|
||||
camera = types.ModuleType("reachy2_sdk.media.camera")
|
||||
camera.CameraView = MagicMock(name="CameraView")
|
||||
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
|
||||
camera_manager.CameraManager = MagicMock(name="CameraManager")
|
||||
|
||||
sdk.media = media
|
||||
media.camera = camera
|
||||
media.camera_manager = camera_manager
|
||||
|
||||
# Register in sys.modules
|
||||
sys.modules.setdefault("reachy2_sdk", sdk)
|
||||
sys.modules.setdefault("reachy2_sdk.media", media)
|
||||
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
|
||||
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
_install_reachy2_sdk_stub()
|
||||
@@ -1,326 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.robots.reachy2 import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Robot,
|
||||
Reachy2RobotConfig,
|
||||
)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_JOINTS = {
|
||||
**REACHY2_NECK_JOINTS,
|
||||
**REACHY2_ANTENNAS_JOINTS,
|
||||
**REACHY2_R_ARM_JOINTS,
|
||||
**REACHY2_L_ARM_JOINTS,
|
||||
}
|
||||
|
||||
PARAMS = [
|
||||
{}, # default config
|
||||
{"with_mobile_base": False},
|
||||
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
|
||||
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
|
||||
{"use_external_commands": True, "disable_torque_on_disconnect": True},
|
||||
{"use_external_commands": True, "with_mobile_base": False, "with_neck": False},
|
||||
{"disable_torque_on_disconnect": False},
|
||||
{"max_relative_target": 5},
|
||||
{"with_right_teleop_camera": False},
|
||||
{"with_left_teleop_camera": False, "with_right_teleop_camera": False},
|
||||
{"with_left_teleop_camera": False, "with_torso_camera": True},
|
||||
]
|
||||
|
||||
|
||||
def _make_reachy2_sdk_mock():
|
||||
class JointSpy:
|
||||
__slots__ = (
|
||||
"present_position",
|
||||
"_goal_position",
|
||||
"_on_set",
|
||||
)
|
||||
|
||||
def __init__(self, present_position=0.0, on_set=None):
|
||||
self.present_position = present_position
|
||||
self._goal_position = present_position
|
||||
self._on_set = on_set
|
||||
|
||||
@property
|
||||
def goal_position(self):
|
||||
return self._goal_position
|
||||
|
||||
@goal_position.setter
|
||||
def goal_position(self, v):
|
||||
self._goal_position = v
|
||||
if self._on_set:
|
||||
self._on_set()
|
||||
|
||||
r = MagicMock(name="ReachySDKMock")
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _connect():
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
r.is_connected.return_value = False
|
||||
|
||||
# Global counter of goal_position sets
|
||||
r._goal_position_set_total = 0
|
||||
|
||||
def _on_any_goal_set():
|
||||
r._goal_position_set_total += 1
|
||||
|
||||
# Mock joints with some dummy positions
|
||||
joints = {
|
||||
k: JointSpy(
|
||||
present_position=float(i),
|
||||
on_set=_on_any_goal_set,
|
||||
)
|
||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||
}
|
||||
r.joints = joints
|
||||
|
||||
# Mock mobile base with some dummy odometry
|
||||
r.mobile_base = MagicMock()
|
||||
r.mobile_base.odometry = {
|
||||
"x": 0.1,
|
||||
"y": -0.2,
|
||||
"theta": 21.3,
|
||||
"vx": 0.001,
|
||||
"vy": 0.002,
|
||||
"vtheta": 0.0,
|
||||
}
|
||||
|
||||
r.connect = MagicMock(side_effect=_connect)
|
||||
r.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
# Mock methods
|
||||
r.turn_on = MagicMock()
|
||||
r.reset_default_limits = MagicMock()
|
||||
r.send_goal_positions = MagicMock()
|
||||
r.turn_off_smoothly = MagicMock()
|
||||
r.mobile_base.set_goal_speed = MagicMock()
|
||||
r.mobile_base.send_speed_command = MagicMock()
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cfg = args[0] if args else kwargs.get("config")
|
||||
name = getattr(cfg, "name", kwargs.get("name", "cam"))
|
||||
image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam"))
|
||||
width = getattr(cfg, "width", kwargs.get("width", 640))
|
||||
height = getattr(cfg, "height", kwargs.get("height", 480))
|
||||
|
||||
cam = MagicMock(name=f"Reachy2CameraMock:{name}")
|
||||
cam.name = name
|
||||
cam.image_type = image_type
|
||||
cam.width = width
|
||||
cam.height = height
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
|
||||
def reachy2(request):
|
||||
with (
|
||||
patch(
|
||||
"lerobot.robots.reachy2.robot_reachy2.ReachySDK",
|
||||
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
|
||||
),
|
||||
patch(
|
||||
"lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera",
|
||||
side_effect=_make_reachy2_camera_mock,
|
||||
),
|
||||
):
|
||||
overrides = request.param
|
||||
cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides)
|
||||
robot = Reachy2Robot(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(reachy2):
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.connect()
|
||||
assert reachy2.is_connected
|
||||
|
||||
reachy2.reachy.turn_on.assert_called_once()
|
||||
reachy2.reachy.reset_default_limits.assert_called_once()
|
||||
|
||||
reachy2.disconnect()
|
||||
assert not reachy2.is_connected
|
||||
|
||||
if reachy2.config.disable_torque_on_disconnect:
|
||||
reachy2.reachy.turn_off_smoothly.assert_called_once()
|
||||
else:
|
||||
reachy2.reachy.turn_off_smoothly.assert_not_called()
|
||||
reachy2.reachy.disconnect.assert_called_once()
|
||||
|
||||
|
||||
def test_get_joints_dict(reachy2):
|
||||
reachy2.connect()
|
||||
|
||||
if reachy2.config.with_neck:
|
||||
assert "neck_yaw.pos" in reachy2.joints_dict
|
||||
assert "neck_pitch.pos" in reachy2.joints_dict
|
||||
assert "neck_roll.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "neck_yaw.pos" not in reachy2.joints_dict
|
||||
assert "neck_pitch.pos" not in reachy2.joints_dict
|
||||
assert "neck_roll.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_antennas:
|
||||
assert "l_antenna.pos" in reachy2.joints_dict
|
||||
assert "r_antenna.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "l_antenna.pos" not in reachy2.joints_dict
|
||||
assert "r_antenna.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_r_arm:
|
||||
assert "r_shoulder_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_shoulder_roll.pos" in reachy2.joints_dict
|
||||
assert "r_elbow_yaw.pos" in reachy2.joints_dict
|
||||
assert "r_elbow_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_roll.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_yaw.pos" in reachy2.joints_dict
|
||||
assert "r_gripper.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "r_shoulder_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_shoulder_roll.pos" not in reachy2.joints_dict
|
||||
assert "r_elbow_yaw.pos" not in reachy2.joints_dict
|
||||
assert "r_elbow_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_roll.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_yaw.pos" not in reachy2.joints_dict
|
||||
assert "r_gripper.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_l_arm:
|
||||
assert "l_shoulder_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_shoulder_roll.pos" in reachy2.joints_dict
|
||||
assert "l_elbow_yaw.pos" in reachy2.joints_dict
|
||||
assert "l_elbow_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_roll.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_yaw.pos" in reachy2.joints_dict
|
||||
assert "l_gripper.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "l_shoulder_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_shoulder_roll.pos" not in reachy2.joints_dict
|
||||
assert "l_elbow_yaw.pos" not in reachy2.joints_dict
|
||||
assert "l_elbow_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_roll.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_yaw.pos" not in reachy2.joints_dict
|
||||
assert "l_gripper.pos" not in reachy2.joints_dict
|
||||
|
||||
|
||||
def test_get_observation(reachy2):
|
||||
reachy2.connect()
|
||||
obs = reachy2.get_observation()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
expected_keys.update(reachy2.cameras.keys())
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
if reachy2.config.with_left_teleop_camera:
|
||||
assert obs["teleop_left"].shape == (
|
||||
reachy2.config.cameras["teleop_left"].height,
|
||||
reachy2.config.cameras["teleop_left"].width,
|
||||
3,
|
||||
)
|
||||
if reachy2.config.with_right_teleop_camera:
|
||||
assert obs["teleop_right"].shape == (
|
||||
reachy2.config.cameras["teleop_right"].height,
|
||||
reachy2.config.cameras["teleop_right"].width,
|
||||
3,
|
||||
)
|
||||
if reachy2.config.with_torso_camera:
|
||||
assert obs["torso_rgb"].shape == (
|
||||
reachy2.config.cameras["torso_rgb"].height,
|
||||
reachy2.config.cameras["torso_rgb"].width,
|
||||
3,
|
||||
)
|
||||
|
||||
|
||||
def test_send_action(reachy2):
|
||||
reachy2.connect()
|
||||
|
||||
action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)}
|
||||
if reachy2.config.with_mobile_base:
|
||||
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
|
||||
|
||||
previous_present_position = {
|
||||
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
|
||||
}
|
||||
returned = reachy2.send_action(action)
|
||||
|
||||
if reachy2.config.max_relative_target is None:
|
||||
assert returned == action
|
||||
|
||||
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
expected_pos = action[motor]
|
||||
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.max_relative_target is None:
|
||||
assert real_pos == expected_pos
|
||||
else:
|
||||
assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min(
|
||||
abs(expected_pos - real_pos), reachy2.config.max_relative_target
|
||||
)
|
||||
|
||||
if reachy2.config.with_mobile_base:
|
||||
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
|
||||
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
|
||||
|
||||
if reachy2.config.use_external_commands:
|
||||
reachy2.reachy.send_goal_positions.assert_not_called()
|
||||
if reachy2.config.with_mobile_base:
|
||||
reachy2.reachy.mobile_base.send_speed_command.assert_not_called()
|
||||
else:
|
||||
reachy2.reachy.send_goal_positions.assert_called_once()
|
||||
if reachy2.config.with_mobile_base:
|
||||
reachy2.reachy.mobile_base.send_speed_command.assert_called_once()
|
||||
|
||||
|
||||
def test_no_part_declared():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2RobotConfig(
|
||||
ip_address="192.168.0.200",
|
||||
with_mobile_base=False,
|
||||
with_l_arm=False,
|
||||
with_r_arm=False,
|
||||
with_neck=False,
|
||||
with_antennas=False,
|
||||
)
|
||||
@@ -1,150 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.teleoperators.reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Teleoperator,
|
||||
Reachy2TeleoperatorConfig,
|
||||
)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_JOINTS = {
|
||||
**REACHY2_NECK_JOINTS,
|
||||
**REACHY2_ANTENNAS_JOINTS,
|
||||
**REACHY2_R_ARM_JOINTS,
|
||||
**REACHY2_L_ARM_JOINTS,
|
||||
}
|
||||
|
||||
PARAMS = [
|
||||
{}, # default config
|
||||
{"with_mobile_base": False},
|
||||
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
|
||||
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
|
||||
{"with_mobile_base": False, "with_neck": False},
|
||||
{"use_present_position": True},
|
||||
]
|
||||
|
||||
|
||||
def _make_reachy2_sdk_mock():
|
||||
r = MagicMock(name="ReachySDKMock")
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _connect():
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
r.is_connected.return_value = False
|
||||
|
||||
# Mock joints with some dummy positions
|
||||
joints = {
|
||||
k: MagicMock(
|
||||
present_position=float(i),
|
||||
goal_position=float(i) + 0.5,
|
||||
)
|
||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||
}
|
||||
r.joints = joints
|
||||
|
||||
# Mock mobile base with some dummy odometry
|
||||
r.mobile_base = MagicMock()
|
||||
r.mobile_base.last_cmd_vel = {
|
||||
"vx": -0.2,
|
||||
"vy": 0.2,
|
||||
"vtheta": 11.0,
|
||||
}
|
||||
r.mobile_base.odometry = {
|
||||
"x": 1.0,
|
||||
"y": 2.0,
|
||||
"theta": 20.0,
|
||||
"vx": 0.1,
|
||||
"vy": -0.1,
|
||||
"vtheta": 8.0,
|
||||
}
|
||||
|
||||
r.connect = MagicMock(side_effect=_connect)
|
||||
r.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
|
||||
def reachy2(request):
|
||||
with (
|
||||
patch(
|
||||
"lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK",
|
||||
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
|
||||
),
|
||||
):
|
||||
overrides = request.param
|
||||
cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides)
|
||||
robot = Reachy2Teleoperator(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(reachy2):
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.connect()
|
||||
assert reachy2.is_connected
|
||||
|
||||
reachy2.disconnect()
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.reachy.disconnect.assert_called_once()
|
||||
|
||||
|
||||
def test_get_action(reachy2):
|
||||
reachy2.connect()
|
||||
action = reachy2.get_action()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
assert set(action.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
if reachy2.config.use_present_position:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
else:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
if reachy2.config.use_present_position:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
else:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
|
||||
|
||||
|
||||
def test_no_part_declared():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2TeleoperatorConfig(
|
||||
ip_address="192.168.0.200",
|
||||
with_mobile_base=False,
|
||||
with_l_arm=False,
|
||||
with_r_arm=False,
|
||||
with_neck=False,
|
||||
with_antennas=False,
|
||||
)
|
||||
Reference in New Issue
Block a user