mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
354 Commits
pr-1484
...
feat/add-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac1c2454c5 | ||
|
|
2e3c116fad | ||
|
|
65c174e9f8 | ||
|
|
005d4bb011 | ||
|
|
779d38fff0 | ||
|
|
c0ffb92735 | ||
|
|
baa9b95b97 | ||
|
|
75ce54e212 | ||
|
|
05a2316a63 | ||
|
|
2437decd3f | ||
|
|
2d2f5d3d60 | ||
|
|
2d608f086a | ||
|
|
1c0ac8e341 | ||
|
|
c4c0105a47 | ||
|
|
1b878c9155 | ||
|
|
724874e063 | ||
|
|
91b110d806 | ||
|
|
519b76110e | ||
|
|
e925ef3f18 | ||
|
|
fbf5f04545 | ||
|
|
9fdec23cee | ||
|
|
d9af2f1b89 | ||
|
|
d2645cb19f | ||
|
|
57f7c8b03e | ||
|
|
e9c795e479 | ||
|
|
abe51eeba3 | ||
|
|
c9cff132c3 | ||
|
|
30c161006d | ||
|
|
0136912fa4 | ||
|
|
ce2b9724bf | ||
|
|
cf86b9300d | ||
|
|
67d6bfee78 | ||
|
|
039de254ea | ||
|
|
a3feadbbfb | ||
|
|
25e22ea3ba | ||
|
|
5e27248bba | ||
|
|
a5e0aae13a | ||
|
|
7f7b45cfbb | ||
|
|
28857dccb1 | ||
|
|
a4d46d4adb | ||
|
|
043b720505 | ||
|
|
d985f4b1db | ||
|
|
ab53de989a | ||
|
|
a56cf87f42 | ||
|
|
12d1629aae | ||
|
|
63e2a2e129 | ||
|
|
2a46f3a53f | ||
|
|
171c355858 | ||
|
|
9ad19d4e81 | ||
|
|
e171fa788a | ||
|
|
b1386fd79e | ||
|
|
b47620cd59 | ||
|
|
a32d988536 | ||
|
|
9571a713df | ||
|
|
b418409b24 | ||
|
|
0a6b3992ee | ||
|
|
e6d19116c4 | ||
|
|
92ea7fc0fb | ||
|
|
46cd157c55 | ||
|
|
52028f5201 | ||
|
|
f5b1ef0045 | ||
|
|
81a4deadc3 | ||
|
|
fef83ce349 | ||
|
|
eb3986e131 | ||
|
|
d45226ad06 | ||
|
|
fe43f93553 | ||
|
|
40e0a311b5 | ||
|
|
13677cb720 | ||
|
|
247d493d06 | ||
|
|
2f00475fc6 | ||
|
|
4687296d93 | ||
|
|
5c2f8ccd14 | ||
|
|
d25e3bd989 | ||
|
|
adcb07bf62 | ||
|
|
67e3383ffc | ||
|
|
ac5a9b90c7 | ||
|
|
f35d24a9c3 | ||
|
|
fbdefb2e3e | ||
|
|
0e39d0f6e6 | ||
|
|
b8eecba63d | ||
|
|
7308aa57a2 | ||
|
|
1fd3b2e2db | ||
|
|
69e48bbe19 | ||
|
|
0db1a67eaf | ||
|
|
ccb8468e9b | ||
|
|
f6198d20c6 | ||
|
|
78e29f4f20 | ||
|
|
fb4bfaf029 | ||
|
|
809a9c6de0 | ||
|
|
f4c11593d4 | ||
|
|
71e6520cd1 | ||
|
|
a5f15db057 | ||
|
|
edec51988d | ||
|
|
ddca6765b8 | ||
|
|
cedaa83bce | ||
|
|
4bb965c283 | ||
|
|
4feaef3436 | ||
|
|
e9aac40ba8 | ||
|
|
386ad61007 | ||
|
|
cac4289619 | ||
|
|
0bda18eab5 | ||
|
|
8ab2227148 | ||
|
|
9dab08dfbc | ||
|
|
05dfa26c54 | ||
|
|
edbba48e81 | ||
|
|
10119c1a59 | ||
|
|
c7ef189da0 | ||
|
|
51efe6dfee | ||
|
|
b0592d9bc8 | ||
|
|
363fe64ff9 | ||
|
|
bbcb12e919 | ||
|
|
3e87b09d34 | ||
|
|
81de27dc9a | ||
|
|
eb94a5f03f | ||
|
|
742708942c | ||
|
|
5a2f9b6589 | ||
|
|
06f0c579b7 | ||
|
|
7890767d34 | ||
|
|
d322cb0220 | ||
|
|
f011173ff6 | ||
|
|
20129cd4c2 | ||
|
|
307823bc8d | ||
|
|
64303781c2 | ||
|
|
dd3e305164 | ||
|
|
cb9cac6a1b | ||
|
|
95f9b45418 | ||
|
|
f9db727647 | ||
|
|
230c7fdfab | ||
|
|
320f7e8450 | ||
|
|
08fbbb318f | ||
|
|
8b98399206 | ||
|
|
237b14a6ec | ||
|
|
2e705ff554 | ||
|
|
d72a3f9c32 | ||
|
|
73ac4f38b2 | ||
|
|
a0e69dd708 | ||
|
|
b207babd9e | ||
|
|
293870d0f6 | ||
|
|
87a8cb6d89 | ||
|
|
69dc3f5c9c | ||
|
|
e4d4754600 | ||
|
|
2e528a8b12 | ||
|
|
b7a9b0689a | ||
|
|
b6b9635be6 | ||
|
|
21b1026872 | ||
|
|
8c3eab32b0 | ||
|
|
29633865c7 | ||
|
|
702749b7d3 | ||
|
|
bf1c737858 | ||
|
|
d07c7347f8 | ||
|
|
57e5e4cc07 | ||
|
|
2743c29a96 | ||
|
|
2bb73ac431 | ||
|
|
9afc4b771c | ||
|
|
f71e224023 | ||
|
|
889de7c415 | ||
|
|
3539251b18 | ||
|
|
1f210bc8a3 | ||
|
|
d70bc4bde9 | ||
|
|
bdbca09cb2 | ||
|
|
e0b292ab51 | ||
|
|
f960f4d8d4 | ||
|
|
9e57ec7837 | ||
|
|
0a7f51f0da | ||
|
|
4ca92a28e9 | ||
|
|
0464dc91b3 | ||
|
|
d32daebf75 | ||
|
|
27cb0c40bd | ||
|
|
12abc9ca86 | ||
|
|
4005065223 | ||
|
|
443fed216c | ||
|
|
42a87e7211 | ||
|
|
034171a89a | ||
|
|
782dff1163 | ||
|
|
8924ccbbab | ||
|
|
792c3d961d | ||
|
|
e998dddcfa | ||
|
|
99c0938b42 | ||
|
|
716029b1e3 | ||
|
|
3848a8f9aa | ||
|
|
f7672e14c7 | ||
|
|
e393af2d88 | ||
|
|
0dcb2caba8 | ||
|
|
4679725957 | ||
|
|
57319062aa | ||
|
|
078f59bfd1 | ||
|
|
36fcea2002 | ||
|
|
2971bdfed5 | ||
|
|
28cd3a6f3a | ||
|
|
c0570b3003 | ||
|
|
eeb8490016 | ||
|
|
854b78975a | ||
|
|
e55d2ffe50 | ||
|
|
1ebd81552c | ||
|
|
65569ba90e | ||
|
|
79293800f1 | ||
|
|
bc765f9e95 | ||
|
|
201311503f | ||
|
|
8cc0232e73 | ||
|
|
6bfcc18e73 | ||
|
|
e096754d14 | ||
|
|
02803f545d | ||
|
|
8503e8e166 | ||
|
|
d6007c6e7d | ||
|
|
50963fcf13 | ||
|
|
051a52a4ce | ||
|
|
2292b514aa | ||
|
|
1f1a01a798 | ||
|
|
faa476f0d2 | ||
|
|
5130b69ece | ||
|
|
aed85241b7 | ||
|
|
21c3ac42ee | ||
|
|
2d3a5fb2be | ||
|
|
a631e4c11c | ||
|
|
222d6f104e | ||
|
|
7a3b424cd3 | ||
|
|
af295fadb5 | ||
|
|
9644e2b086 | ||
|
|
6ccf083127 | ||
|
|
bb774e7acd | ||
|
|
dcbbeab80b | ||
|
|
b71ac34214 | ||
|
|
c237d1379e | ||
|
|
cf963eb1b0 | ||
|
|
4293b6a4fb | ||
|
|
7a75bb9f61 | ||
|
|
0c1d4cb323 | ||
|
|
c6212d585d | ||
|
|
7c8ab8e2d6 | ||
|
|
1de75c46c0 | ||
|
|
4ad109cff8 | ||
|
|
8994252019 | ||
|
|
9832daf08d | ||
|
|
39d8f45810 | ||
|
|
30fcd3d417 | ||
|
|
039b437ef0 | ||
|
|
7582a0a2b0 | ||
|
|
25388d0947 | ||
|
|
7152bc8aa7 | ||
|
|
5b46dc0b6a | ||
|
|
4273f1f384 | ||
|
|
97194bf7f3 | ||
|
|
0ac026b521 | ||
|
|
ff7cfdaf40 | ||
|
|
57c97762e1 | ||
|
|
a38bb15e79 | ||
|
|
3ceaee999d | ||
|
|
d485dc1313 | ||
|
|
329d103453 | ||
|
|
9f46a3d8f9 | ||
|
|
c9ca9e4316 | ||
|
|
5a57e6f4a7 | ||
|
|
a2f5c34625 | ||
|
|
1f1e1bcfe8 | ||
|
|
e047074825 | ||
|
|
c2e761437d | ||
|
|
fedac994c3 | ||
|
|
7d558d058e | ||
|
|
1d3e1cbdbd | ||
|
|
0ccc957d5c | ||
|
|
a4d487bc1d | ||
|
|
8ca03a7255 | ||
|
|
f2ed2bfb2f | ||
|
|
40675ec76c | ||
|
|
9e34c1d731 | ||
|
|
857f335be9 | ||
|
|
fc4a95f187 | ||
|
|
4fe1880887 | ||
|
|
6fa859fa19 | ||
|
|
2abfa5838d | ||
|
|
3d119c0ccb | ||
|
|
a32081757d | ||
|
|
56c04ffc53 | ||
|
|
715d4557af | ||
|
|
6541982dff | ||
|
|
43bc9404bb | ||
|
|
375499c323 | ||
|
|
17a4447cef | ||
|
|
287dc13d96 | ||
|
|
02a1cf6a4e | ||
|
|
34cd1e47bf | ||
|
|
74d56834af | ||
|
|
dd80dbb4cd | ||
|
|
bc020ee0a4 | ||
|
|
a15767aff1 | ||
|
|
9af0a9bf37 | ||
|
|
e2c8bc6948 | ||
|
|
2c68c6ca40 | ||
|
|
dd1f33e5ed | ||
|
|
2c1bb766ff | ||
|
|
c1c71fb994 | ||
|
|
2d56f35071 | ||
|
|
64ce2669ca | ||
|
|
f527adf7a9 | ||
|
|
6a77189f50 | ||
|
|
e4a6d035f9 | ||
|
|
794f6e00fc | ||
|
|
97494c6a39 | ||
|
|
9358d334c7 | ||
|
|
c85a9253e7 | ||
|
|
8d659a6aa9 | ||
|
|
f6a2396484 | ||
|
|
7a7af82e35 | ||
|
|
7f23972f3f | ||
|
|
3362b665e6 | ||
|
|
eeeccdba53 | ||
|
|
bd5b181dfd | ||
|
|
858678786a | ||
|
|
0f972661e1 | ||
|
|
2e9b144c56 | ||
|
|
fa8ba9e4e2 | ||
|
|
2037cc0219 | ||
|
|
5006da72ff | ||
|
|
ad0bacbfe4 | ||
|
|
e33ca2c980 | ||
|
|
f0505e81cc | ||
|
|
1f7ddc1d76 | ||
|
|
ce63cfdb25 | ||
|
|
d6f1359e69 | ||
|
|
2357d4aceb | ||
|
|
d6ccdc222c | ||
|
|
9bd0788131 | ||
|
|
1ae62c28f7 | ||
|
|
baf6e66c3d | ||
|
|
a065bd61ae | ||
|
|
5dc3c74e64 | ||
|
|
4214b01703 | ||
|
|
b974e5541f | ||
|
|
fd64dc84ae | ||
|
|
06988b2135 | ||
|
|
7ed7570b17 | ||
|
|
e2d13ba7e4 | ||
|
|
f6c1049474 | ||
|
|
2b24feb604 | ||
|
|
a13e49073c | ||
|
|
2c7e0f17b6 | ||
|
|
418866007e | ||
|
|
5ab418dbeb | ||
|
|
95f61ee9d4 | ||
|
|
ac89c8d226 | ||
|
|
d75d904e43 | ||
|
|
ea4d8d990c | ||
|
|
c93cbb8311 | ||
|
|
c0137e89b9 | ||
|
|
3111ba78ad | ||
|
|
3d3a176940 | ||
|
|
212c6095a2 | ||
|
|
48469ec674 | ||
|
|
c7dfd32b43 | ||
|
|
731fb6ebaf | ||
|
|
13e124302f | ||
|
|
59bdd29106 | ||
|
|
124829104b | ||
|
|
21cd2940a9 |
@@ -37,7 +37,7 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.33.1
|
||||
rev: v1.34.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
@@ -48,7 +48,7 @@ repos:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.13
|
||||
rev: v0.12.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -62,12 +62,12 @@ repos:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.9.0
|
||||
rev: v1.11.0
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.3
|
||||
rev: 1.8.6
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: ["-c", "pyproject.toml"]
|
||||
|
||||
23
README.md
23
README.md
@@ -22,6 +22,29 @@
|
||||
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
Build Your Own HopeJR Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img
|
||||
src="media/hope_jr/hopejr.png?raw=true"
|
||||
alt="HopeJR robot"
|
||||
title="HopeJR robot"
|
||||
style="width: 60%;"
|
||||
/>
|
||||
|
||||
<p><strong>Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!</strong></p>
|
||||
<p>Control it with exoskeletons and gloves for precise hand movements.</p>
|
||||
<p>Perfect for advanced manipulation tasks! 🤖</p>
|
||||
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
See the full HopeJR tutorial here.</a></p>
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
Build Your Own SO-101 Robot!</a></p>
|
||||
|
||||
@@ -17,12 +17,16 @@
|
||||
title: Train a Robot with RL
|
||||
- local: hilserl_sim
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: smolvla
|
||||
title: Finetune SmolVLA
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: hope_jr
|
||||
title: Hope Jr
|
||||
- local: so101
|
||||
title: SO-101
|
||||
- local: so100
|
||||
|
||||
272
docs/source/async.mdx
Normal file
272
docs/source/async.mdx
Normal file
@@ -0,0 +1,272 @@
|
||||
# Asynchronous Inference
|
||||
|
||||
With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
|
||||
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
|
||||
**Try async inference with all the policies** supported by LeRobot!
|
||||
|
||||
**What you'll learn:**
|
||||
1. Why asynchronous inference matters and how it compares to more traditional, sequential inference.
|
||||
2. How to spin up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
|
||||
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
|
||||
|
||||
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
|
||||
|
||||
|
||||
In a nutshell: with *async inference*, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
|
||||
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
|
||||
|
||||
---
|
||||
## Getting started with async inference
|
||||
|
||||
You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
|
||||
|
||||
First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
|
||||
|
||||
```shell
|
||||
pip install -e ".[async]"
|
||||
```
|
||||
|
||||
Then, spin up a policy server (in one terminal, or on a separate machine) specifying the host address and port for the client to connect to.
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python src/lerobot/scripts/server/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters is specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
In summary, you need to specify instructions for:
|
||||
- `SERVER`: the address and port of the policy server
|
||||
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
|
||||
- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device the server should use, and how many actions to output at once (capped at the policy max actions value).
|
||||
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
|
||||
|
||||
Importantly,
|
||||
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
|
||||
- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `src/lerobot/scripts/server/robot_client.py`.
|
||||
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
|
||||
|
||||
Done! You should see your robot moving around by now 😉
|
||||
---
|
||||
|
||||
## Async vs. synchronous inference
|
||||
|
||||
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in *idle frames*, frames where the robot idly awaits the policy's output: a new action chunk.
|
||||
In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
|
||||
With robotics models increasing in size, this problem risks becoming only more severe.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png" width="80%"></img>
|
||||
</p>
|
||||
<p align="center"><i>Synchronous inference</i> makes the robot idle while the policy is computing the next chunk of actions.</p>
|
||||
|
||||
To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
|
||||
Crucially, with async inference, the next action chunk is computed *before* the current one is exhausted, resulting in no idleness.
|
||||
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png" width="80%"></img>
|
||||
</p>
|
||||
<p align="center"><i>Asynchronous inference</i> results in no idleness because the next chunk is computed before the current chunk is exhausted.</p>
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Start the Policy Server
|
||||
|
||||
Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
|
||||
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
|
||||
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
|
||||
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.scripts.server.policy_server \
|
||||
--host="localhost" \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
```python
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
port=8080,
|
||||
)
|
||||
serve(config)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This listens on `localhost:8080` for an incoming connection from the associated `RobotClient`, which will communicate which policy to run during the first client-server handshake.
|
||||
|
||||
---
|
||||
|
||||
## Launch the Robot Client
|
||||
|
||||
`RobotClient` is a wrapper around a `Robot` instance, which it connects to the (possibly remote) `PolicyServer`.
|
||||
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
|
||||
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
```python
|
||||
import threading
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.robot_client import RobotClient
|
||||
from lerobot.scripts.server.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
# these cameras must match the ones expected by the policy
|
||||
# check the config.json on the Hub for the policy you are using
|
||||
camera_cfg = {
|
||||
"top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="follower_so100",
|
||||
cameras=camera_cfg
|
||||
)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="fracapuano/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Specify the task
|
||||
task = "Don't do anything, stay still"
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The following two parameters are key in every setup:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Hyperparameter</th>
|
||||
<th>Default</th>
|
||||
<th>What it does</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><code>actions_per_chunk</code></td>
|
||||
<td>50</td>
|
||||
<td>How many actions the policy outputs at once. Typical values: 10-50.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><code>chunk_size_threshold</code></td>
|
||||
<td>0.7</td>
|
||||
<td>When the fraction of the action queue that is still full drops to this threshold or below, the client sends a fresh observation. Value in [0, 1].</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<Tip>
|
||||
Different values of `actions_per_chunk` and `chunk_size_threshold` do result in different behaviours.
|
||||
</Tip>
|
||||
|
||||
On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
|
||||
However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
|
||||
|
||||
On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updated action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
|
||||
This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
|
||||
|
||||
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
|
||||
|
||||
### Tuning async inference for your setup
|
||||
|
||||
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case, keeping in mind that smaller policies require fewer computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` result in sending an observation at every step (more bandwidth, relies on a good world model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png" width="80%"></img>
|
||||
</p>
|
||||
<p align="center"><i>The action queue size is plotted at runtime when the `--debug_visualize_queue_size` flag is passed, for various levels of `chunk_size_threshold` (`g` in the SmolVLA paper).</i></p>
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
|
||||
|
||||
**Key Takeaways:**
|
||||
|
||||
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
|
||||
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
|
||||
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
|
||||
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
1
docs/source/hope_jr.mdx
Symbolic link
1
docs/source/hope_jr.mdx
Symbolic link
@@ -0,0 +1 @@
|
||||
../../src/lerobot/robots/hope_jr/hope_jr.mdx
|
||||
@@ -282,6 +282,12 @@ Your dataset will be automatically tagged with `LeRobot` for the community to fi
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
|
||||
|
||||
You can also push your local dataset to the Hub manually, running:
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
```
|
||||
|
||||
|
||||
#### Record function
|
||||
|
||||
The `record` function provides a suite of tools for capturing and managing data during robot operation:
|
||||
|
||||
BIN
media/hope_jr/hopejr.png
Normal file
BIN
media/hope_jr/hopejr.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 72 KiB |
@@ -46,7 +46,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"cmake>=3.29.0.1",
|
||||
"datasets>=2.19.0",
|
||||
"datasets>=2.19.0,<=3.6.0",
|
||||
"deepdiff>=7.0.1",
|
||||
"diffusers>=0.27.2",
|
||||
"draccus==0.10.0",
|
||||
@@ -79,13 +79,14 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
|
||||
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
|
||||
dora = [
|
||||
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
|
||||
]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0"]
|
||||
gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"]
|
||||
hopejr = ["feetech-servo-sdk>=1.0.0", "pygame>=2.5.1"]
|
||||
kinematics = ["placo>=0.9.6"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
@@ -104,6 +105,7 @@ hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio
|
||||
umi = ["imagecodecs>=2024.1.1"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
|
||||
async = ["grpcio==1.71.0", "matplotlib>=3.10.3"]
|
||||
|
||||
[tool.poetry]
|
||||
requires-poetry = ">=2.1"
|
||||
@@ -114,7 +116,7 @@ packages = [
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py", "*.part", "*.stl"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
@@ -131,7 +133,7 @@ exclude_dirs = [
|
||||
"src/lerobot/policies/pi0/conversion_scripts",
|
||||
"src/lerobot/scripts/push_dataset_to_hub.py",
|
||||
]
|
||||
skips = ["B101", "B311", "B404", "B603"]
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
[tool.typos]
|
||||
default.extend-ignore-re = [
|
||||
@@ -146,6 +148,12 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
]
|
||||
|
||||
[tool.typos.files]
|
||||
extend-exclude = [
|
||||
"*.stl",
|
||||
"*.part",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
175
src/lerobot/bi_teleoperate.py
Normal file
175
src/lerobot/bi_teleoperate.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Four-channel bilateral teleoperation between two torque-controlled SO101 arms.

Each control cycle reads position, velocity and effort from both arms, computes
a symmetric torque command for each of them (position tracking + velocity
damping + force reflection), sends the commands, logs the data to rerun and
periodically redraws a status table in the terminal.

Stop with Ctrl-C: both arms are disconnected on the way out.
"""

import math
import sys
import time

from lerobot.robots.so101_follower_torque.config_so101_follower_t import SO101FollowerTConfig
from lerobot.robots.so101_follower_torque.so101_follower_t import SO101FollowerT
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data

FRQ = 100  # target control-loop rate (Hz)
PRINT_HZ = 10  # terminal refresh rate (Hz)
RERUN_HZ = 100  # rerun logging rate (Hz)
ESC_CLR_EOL = "\x1b[K"  # ANSI: erase from cursor to end of line
CURSOR_UP = "\x1b[F"  # ANSI: move cursor to the beginning of the previous line

# Loop-count dividers. Clamped to 1 so that configuring a rate higher than FRQ
# cannot turn the `loop_count % divider` checks below into a modulo by zero.
RERUN_EVERY = max(1, FRQ // RERUN_HZ)
PRINT_EVERY = max(1, FRQ // PRINT_HZ)

follower_cfg = SO101FollowerTConfig(
    port="/dev/tty.usbmodem58760432961",
    id="follower_arm_torque",
)

# The leader is driven through the same torque-capable follower class so that it
# can both be read and receive reflected torques.
leader_cfg = SO101FollowerTConfig(
    port="/dev/tty.usbmodem58760432571",
    id="leader_arm_torque",
)

follower = SO101FollowerT(follower_cfg)
leader = SO101FollowerT(leader_cfg)

try:
    follower.connect()
    leader.connect()

    _init_rerun("bilateral_teleoperation")

    print("Starting 4-channel bilateral teleoperation")
    first_print = True
    loop_count = 0
    tic_prev = time.perf_counter()

    while True:
        tic = time.perf_counter()

        obs_l, obs_f = leader.get_observation(), follower.get_observation()

        # dt is only used for the rate display below.
        dt = tic - tic_prev
        tic_prev = tic
        if dt <= 0.0:
            dt = 0.01  # avoid div-by-zero

        debug_info_f, debug_info_l = {}, {}

        pos_f = {j: obs_f[f"{j}.pos"] for j in follower.bus.motors}
        vel_f = {j: obs_f[f"{j}.vel"] for j in follower.bus.motors}
        tau_reaction_f = {j: obs_f[f"{j}.effort"] for j in follower.bus.motors}

        pos_l = {j: obs_l[f"{j}.pos"] for j in leader.bus.motors}
        vel_l = {j: obs_l[f"{j}.vel"] for j in leader.bus.motors}
        tau_reaction_l = {j: obs_l[f"{j}.effort"] for j in leader.bus.motors}

        # Joint-specific control gains (the follower's gains are used for both arms)
        kp_gains = follower.kp_gains
        kd_gains = follower.kd_gains
        kf_gains = follower.kf_gains

        # Compute torque commands: each arm tracks the other one and is pushed
        # against the sum of both measured reaction torques.
        tau_cmd_f = [
            kp_gains[j] * (pos_l[j] - pos_f[j])  # Position tracking
            + kd_gains[j] * (vel_l[j] - vel_f[j])  # Velocity damping
            + kf_gains[j] * (-tau_reaction_l[j] - tau_reaction_f[j])  # Force reflection
            for j in follower.bus.motors
        ]

        tau_cmd_l = [
            kp_gains[j] * (pos_f[j] - pos_l[j])  # Position tracking
            + kd_gains[j] * (vel_f[j] - vel_l[j])  # Velocity damping
            + kf_gains[j] * (-tau_reaction_f[j] - tau_reaction_l[j])  # Force reflection
            for j in leader.bus.motors
        ]

        # Store debug info
        # NOTE(review): indexing both arms with the follower's motor names assumes
        # leader and follower expose identical motor names in the same order.
        for i, j in enumerate(follower.bus.motors):
            debug_info_f[j] = {
                "τ_reaction": tau_reaction_f[j],
                "τ_ref": tau_cmd_f[i],
                "θ_err": pos_l[j] - pos_f[j],
                "ω_err": vel_l[j] - vel_f[j],
                "τ_err": -tau_reaction_l[j] - tau_reaction_f[j],
            }
            debug_info_l[j] = {
                "τ_reaction": tau_reaction_l[j],
                "τ_ref": tau_cmd_l[i],
                "θ_err": pos_f[j] - pos_l[j],
                "ω_err": vel_f[j] - vel_l[j],
                "τ_err": -tau_reaction_f[j] - tau_reaction_l[j],
            }

        # Send torques to both arms
        follower.send_action({f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(follower.bus.motors)})
        leader.send_action({f"{m}.effort": tau_cmd_l[i] for i, m in enumerate(leader.bus.motors)})

        observation = {
            "follower_joint_angles": pos_f,  # θ_f: current angles
            "follower_angular_velocities": vel_f,  # ω_f: current velocities
            "follower_external_torques": tau_reaction_f,  # τ_ext: measured minus deterministic components
        }

        action = {
            "leader_target_angles": pos_l,  # θ_leader[τ]: absolute target angles
            "leader_target_velocities": vel_l,  # ω_leader[τ]: absolute target velocities
            "leader_interaction_torques": tau_reaction_l,  # τ_leader[τ]: cmd minus deterministic components
        }

        if loop_count % RERUN_EVERY == 0:
            log_rerun_data(observation, action)

        loop_count += 1
        if loop_count % PRINT_EVERY == 0:
            hz = 1.0 / dt

            lines = [f"Loop {hz:6.1f} Hz Δt {dt * 1e3:5.2f} ms"]
            lines.append("=" * 106)
            lines.append("LEADER ARM TORQUE ANALYSIS:")
            lines.append(f"{'Joint':<13}{'Pos':>8}{'React':>6}{'Cmd':>6}")
            lines.append(f"{'':13}{'(deg)':>8}{'(Nm)':>6}{'(Nm)':>6}")
            lines.append("-" * 86)

            # Positions are converted with math.degrees, i.e. presumably reported in radians.
            for i, j in enumerate(leader.bus.motors):
                debug_l = debug_info_l[j]

                lines.append(
                    f"{j:<13s}{math.degrees(pos_l[j]):+8.1f}{debug_l['τ_reaction']:+6.2f}{tau_cmd_l[i]:+6.2f}"
                )

            lines.append("")
            lines.append("FOLLOWER ARM TORQUE ANALYSIS:")
            lines.append(f"{'Joint':<13}{'Pos':>8}{'React':>6}{'Cmd':>6}")
            lines.append(f"{'':13}{'(deg)':>8}{'(Nm)':>6}{'(Nm)':>6}")
            lines.append("-" * 86)

            for i, j in enumerate(follower.bus.motors):
                debug_f = debug_info_f[j]

                lines.append(
                    f"{j:<13s}{math.degrees(pos_f[j]):+8.1f}{debug_f['τ_reaction']:+6.2f}{tau_cmd_f[i]:+6.2f}"
                )

            lines.append("")
            lines.append("=" * 86)
            lines.append("TORQUE COMPONENT EXPLANATIONS:")
            lines.append("• Pos (joint pos) = Joint position in degrees")
            lines.append("• React (reaction) = External forces (human interaction, contact)")
            lines.append("• Meas (measured) = Raw torque from motor current sensor")
            lines.append("• Cmd (command) = Final torque sent to motor")
            lines.append("-" * 86)
            lines.append(
                "Cmd = Track + Vel + Force + (Added as feedforward in send_action: Grav + Inert + Frict)"
            )
            lines.append("React = Meas - Grav - Inert - Frict (external forces)")
            lines.append("Force = Kf × (reflect_other_robot - React) (telepresence)")
            lines.append("Frict = b_visc×ω + f_coulomb×sign(ω) (transparency)")
            lines.append(
                f"Joint Gains: shoulder_pan Kp={kp_gains['shoulder_pan']:.1f} | shoulder_pan Kd={kd_gains['shoulder_pan']:.1f} | shoulder_pan Kf={kf_gains['shoulder_pan']:.1f}"
            )
            lines.append(
                f"Friction Comp, Viscous: {follower.friction_viscous['shoulder_pan']:.3f} | Coulomb: {follower.friction_coulomb['shoulder_pan']:.3f} (robot-class)"
            )

            block = "\n".join(lines)
            if first_print:
                sys.stdout.write(block + "\n")
                first_print = False
            else:
                # Move the cursor back to the top of the previous table and overwrite it in place.
                sys.stdout.write(CURSOR_UP * len(lines) + ESC_CLR_EOL + block + "\n")
            sys.stdout.flush()

        # Spend whatever is left of the 1/FRQ period so the loop runs at roughly FRQ Hz.
        busy_wait(max(0.0, 1.0 / FRQ - (time.perf_counter() - tic)))
except KeyboardInterrupt:
    print("\nBilateral teleoperation stopped by user")
finally:
    # Always release both arms, even when the loop is interrupted or an arm
    # failed to connect, so the serial ports are not left open.
    for arm in (follower, leader):
        try:
            if arm.is_connected:
                arm.disconnect()
        except Exception as exc:  # keep going so the other arm is still released
            print(f"Failed to disconnect {arm}: {exc}")
|
||||
@@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
lekiwi,
|
||||
make_robot_from_config,
|
||||
@@ -45,6 +46,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
|
||||
@@ -18,12 +18,16 @@ Provides the OpenCVCamera class for capturing frames from cameras using OpenCV.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any, Dict, List
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
@@ -108,7 +112,8 @@ class OpenCVCamera(Camera):
|
||||
self.config = config
|
||||
self.index_or_path = config.index_or_path
|
||||
|
||||
self.fps = config.fps
|
||||
self.wanted_fps = config.fps
|
||||
self.camera_fps = None
|
||||
self.color_mode = config.color_mode
|
||||
self.warmup_s = config.warmup_s
|
||||
|
||||
@@ -196,10 +201,9 @@ class OpenCVCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
# We don't set the FPS. We GET the actual (max) FPS from the camera.
|
||||
self.camera_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
logger.info(f"{self} is running at its default/max FPS: {self.camera_fps:.2f}")
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
@@ -312,19 +316,23 @@ class OpenCVCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
# Start the background capture thread if it's not running
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
# Perform an initial blocking read to populate the first frame
|
||||
ret, frame = self.videocapture.read()
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} failed to read initial frame.")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
self.latest_frame = self._postprocess_image(frame)
|
||||
self._start_read_thread()
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
|
||||
processed_frame = self._postprocess_image(frame, color_mode)
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: Read thread started but no frame is available for {self}.")
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return processed_frame
|
||||
return frame.copy()
|
||||
|
||||
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
@@ -382,16 +390,23 @@ class OpenCVCamera(Camera):
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
ret, frame = self.videocapture.read()
|
||||
if not ret or frame is None:
|
||||
logger.warning(f"Failed to read frame in background for {self}.")
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
processed_frame = self._postprocess_image(frame)
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_frame = processed_frame
|
||||
|
||||
self.new_frame_event.set()
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if not self.is_connected:
|
||||
break
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
|
||||
@@ -60,6 +60,8 @@ def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return cv2.CAP_AVFOUNDATION
|
||||
else:
|
||||
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return cv2.CAP_ANY
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -37,21 +37,6 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
# Multi-dataset support
|
||||
sampling_weights: str | None = None
|
||||
max_action_dim: int | None = None
|
||||
max_state_dim: int | None = None
|
||||
max_num_images: int | None = None
|
||||
max_image_dim: int | None = None
|
||||
train_on_all_features: bool = False
|
||||
features_version: int = 0
|
||||
discard_first_n_frames: int = 0
|
||||
min_fps: int = 1
|
||||
max_fps: int = 100
|
||||
discard_first_idle_frames: bool = False
|
||||
motion_threshold: float = 5e-2
|
||||
motion_window_size: int = 10
|
||||
motion_buffer: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
@@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
|
||||
# HACK: Parse the original config to get the config subclass, so that we can
|
||||
# apply cli overrides.
|
||||
# This is very ugly, ideally we'd like to be able to do that natively with draccus
|
||||
# something like --policy.path (in addition to --policy.type)
|
||||
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_overrides)
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config.pop("type")
|
||||
with tempfile.NamedTemporaryFile("w+") as f:
|
||||
json.dump(config, f)
|
||||
config_file = f.name
|
||||
f.flush()
|
||||
|
||||
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||
|
||||
@@ -22,16 +22,15 @@ OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
OBS_IMAGE_2 = "observation.image2"
|
||||
OBS_IMAGE_3 = "observation.image3"
|
||||
OBS_IMAGE_4 = "observation.image4"
|
||||
REWARD = "next.reward"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TASK = "task"
|
||||
ROBOT_TYPE = "robot_type"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
|
||||
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> bool:
    """Tell whether a batch of tensors is a candidate for padding.

    Only the first element is inspected: the batch is considered paddable as
    soon as that element has at least one dimension (0-d scalars cannot be
    padded). The check that sizes actually differ along ``pad_dim`` is
    intentionally not performed, so ``pad_dim`` is currently unused.

    Args:
        values: Tensors of one feature, one per sample.
        pad_dim: Dimension along which padding would be applied (unused).

    Returns:
        ``True`` if the first tensor has one or more dimensions, ``False`` for
        scalars or for an empty list.
    """
    if not values:
        # Nothing to pad; previously this raised IndexError on values[0].
        return False
    return len(values[0].shape) > 0  # and len(set([v.shape[pad_dim] for v in values])) > 1
|
||||
|
||||
|
||||
def pad_tensor(
    tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
    """Right-pad ``tensor`` along ``pad_dim`` until it has ``max_size`` entries.

    Args:
        tensor: Tensor to pad. A ``numpy.ndarray`` is also accepted; it is
            converted to a tensor for padding and converted back before returning.
        max_size: Target size of dimension ``pad_dim``.
        pad_dim: Dimension to pad (negative values count from the end).
        pad_value: Fill value for the padded region.

    Returns:
        The padded tensor (or array). If the input is already at least
        ``max_size`` long along ``pad_dim`` it is returned without padding.
    """
    is_numpy = isinstance(tensor, np.ndarray)
    if is_numpy:
        tensor = torch.tensor(tensor)
    pad = max_size - tensor.shape[pad_dim]
    if pad > 0:
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # every dimension after pad_dim needs an explicit (0, 0) pair first.
        # Previously (0, pad) was always used, which padded the last dimension
        # regardless of pad_dim.
        dim = pad_dim if pad_dim >= 0 else tensor.dim() + pad_dim
        trailing_dims = tensor.dim() - 1 - dim
        pad_sizes = (0, 0) * trailing_dims + (0, pad)  # pad right
        tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
    return tensor.numpy() if is_numpy else tensor
|
||||
|
||||
|
||||
def pad_list_of_tensors(
    tensors: List[torch.Tensor], pad_dim: int = -1, pad_value: float = 0.0
) -> List[torch.Tensor]:
    """Pad every tensor of a list to the largest size found along ``pad_dim``.

    Args:
        tensors: Tensors to bring to a common size along ``pad_dim``.
        pad_dim: Dimension whose size is equalised.
        pad_value: Fill value for the padded region.

    Returns:
        A new list with each tensor padded through ``pad_tensor``; an empty
        list when ``tensors`` is empty.
    """
    if not tensors:
        # max() on an empty sequence raises ValueError; there is nothing to pad.
        return []
    max_size = max(t.shape[pad_dim] for t in tensors)
    return [pad_tensor(t, max_size, pad_dim=pad_dim, pad_value=pad_value) for t in tensors]
|
||||
|
||||
|
||||
def multidataset_collate_fn(
|
||||
batch: List[Dict[str, torch.Tensor]],
|
||||
pad_dim: int = -1,
|
||||
pad_value: float = 0.0,
|
||||
keys_to_max_dim: dict = {},
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Custom collate function to pad tensors with multiple dimensions.
|
||||
|
||||
Args:
|
||||
batch (List[Dict[str, torch.Tensor]]): List of dataset samples (each sample is a dictionary).
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: Batch with padded tensors.
|
||||
"""
|
||||
batch_keys = batch[0].keys()
|
||||
collated_batch = [{} for _ in range(len(batch))]
|
||||
# FIXME(mshukor): pad to max shape per feature type
|
||||
for key in batch_keys:
|
||||
values = [sample[key] for sample in batch]
|
||||
if (
|
||||
key in keys_to_max_dim
|
||||
and isinstance(values[0], torch.Tensor)
|
||||
and is_batch_need_padding(values, pad_dim=pad_dim)
|
||||
and keys_to_max_dim[key] is not None
|
||||
):
|
||||
max_size = keys_to_max_dim[key]
|
||||
for i in range(len(batch)):
|
||||
collated_batch[i][key] = pad_tensor(
|
||||
batch[i][key], max_size, pad_dim=pad_dim, pad_value=pad_value
|
||||
)
|
||||
else:
|
||||
for i in range(len(batch)):
|
||||
collated_batch[i][key] = batch[i][key]
|
||||
collated_batch = default_collate(collated_batch)
|
||||
|
||||
return collated_batch
|
||||
@@ -125,30 +125,9 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
# Filter out stats that don't have required keys
|
||||
valid_stats = []
|
||||
for s in stats_ft_list:
|
||||
if all(key in s for key in ["mean", "std", "count", "min", "max"]):
|
||||
valid_stats.append(s)
|
||||
else:
|
||||
# If count is missing, add it with a default value
|
||||
if "count" not in s:
|
||||
s["count"] = np.array([1]) # Default count
|
||||
valid_stats.append(s)
|
||||
|
||||
if not valid_stats:
|
||||
# If no valid stats, return empty stats
|
||||
return {
|
||||
"min": np.array([0]),
|
||||
"max": np.array([0]),
|
||||
"mean": np.array([0]),
|
||||
"std": np.array([0]),
|
||||
"count": np.array([0]),
|
||||
}
|
||||
|
||||
means = np.stack([s["mean"] for s in valid_stats])
|
||||
variances = np.stack([s["std"] ** 2 for s in valid_stats])
|
||||
counts = np.stack([s["count"] for s in valid_stats])
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||
total_count = counts.sum(axis=0)
|
||||
|
||||
# Prepare weighted mean by matching number of dimensions
|
||||
@@ -165,8 +144,8 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
"min": np.min(np.stack([s["min"] for s in valid_stats]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in valid_stats]), axis=0),
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
"std": np.sqrt(total_variance),
|
||||
"count": total_count,
|
||||
|
||||
@@ -32,8 +32,6 @@ IMAGENET_STATS = {
|
||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||
}
|
||||
|
||||
from lerobot.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
@@ -83,77 +81,35 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
if "," in cfg.dataset.repo_id:
|
||||
repo_id = cfg.dataset.repo_id.split(",")
|
||||
repo_id = [r for r in repo_id if r]
|
||||
else:
|
||||
repo_id = cfg.dataset.repo_id
|
||||
sampling_weights = cfg.dataset.sampling_weights.split(",") if cfg.dataset.sampling_weights else None
|
||||
feature_keys_mapping = FEATURE_KEYS_MAPPING
|
||||
if isinstance(repo_id, str):
|
||||
revision = getattr(cfg.dataset, "revision", None)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
revision=revision,
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=getattr(cfg.dataset, "root", None),
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=revision,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
download_videos=True,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
max_action_dim=cfg.dataset.max_action_dim,
|
||||
max_state_dim=cfg.dataset.max_state_dim,
|
||||
max_num_images=cfg.dataset.max_num_images,
|
||||
max_image_dim=cfg.dataset.max_image_dim,
|
||||
)
|
||||
else:
|
||||
delta_timestamps = {}
|
||||
episodes = {}
|
||||
for i in range(len(repo_id)):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
repo_id[i],
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
) # FIXME(mshukor): ?
|
||||
delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes)
|
||||
# training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None)
|
||||
# FIXME: (jadechoghari): check support for training features
|
||||
training_features = None
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
repo_id,
|
||||
cfg.dataset.repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
episodes=episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
# delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
download_videos=True,
|
||||
sampling_weights=sampling_weights,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
max_action_dim=cfg.policy.max_action_dim,
|
||||
max_state_dim=cfg.policy.max_state_dim,
|
||||
max_num_images=cfg.dataset.max_num_images,
|
||||
max_image_dim=cfg.dataset.max_image_dim,
|
||||
train_on_all_features=cfg.dataset.train_on_all_features,
|
||||
training_features=training_features,
|
||||
discard_first_n_frames=cfg.dataset.discard_first_n_frames,
|
||||
min_fps=cfg.dataset.min_fps,
|
||||
max_fps=cfg.dataset.max_fps,
|
||||
discard_first_idle_frames=cfg.dataset.discard_first_idle_frames,
|
||||
motion_threshold=cfg.dataset.motion_threshold,
|
||||
motion_window_size=cfg.dataset.motion_window_size,
|
||||
motion_buffer=cfg.dataset.motion_buffer,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
)
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -31,16 +30,8 @@ from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.constants import (
|
||||
ACTION,
|
||||
HF_LEROBOT_HOME,
|
||||
OBS_ENV_STATE,
|
||||
OBS_STATE,
|
||||
)
|
||||
from lerobot.datasets.compute_stats import ( # aggregate_stats_per_robot_type,
|
||||
aggregate_stats,
|
||||
compute_episode_stats,
|
||||
)
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
@@ -50,6 +41,7 @@ from lerobot.datasets.utils import (
|
||||
_validate_feature_names,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_empty_dataset_info,
|
||||
@@ -66,34 +58,12 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
map_dict_keys,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
write_json,
|
||||
# keep_datasets_with_the_same_features_per_robot_type,
|
||||
# map_dict_pad_keys,
|
||||
# keep_datasets_with_valid_fps,
|
||||
# find_start_of_motion,
|
||||
)
|
||||
|
||||
# mustafa stuff here
|
||||
from lerobot.datasets.utils_must import (
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGE_2,
|
||||
OBS_IMAGE_3,
|
||||
ROBOT_TYPE_KEYS_MAPPING,
|
||||
TASKS_KEYS_MAPPING,
|
||||
aggregate_stats_per_robot_type,
|
||||
create_padded_features,
|
||||
find_start_of_motion,
|
||||
keep_datasets_with_the_same_features_per_robot_type,
|
||||
keep_datasets_with_valid_fps,
|
||||
map_dict_keys,
|
||||
pad_tensor,
|
||||
reshape_features_to_max_dim,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
@@ -104,15 +74,6 @@ from lerobot.datasets.video_utils import (
|
||||
)
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
|
||||
|
||||
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
    """Return the index at which sustained motion starts in a velocity trace.

    A window of ``window_size`` consecutive samples is slid over ``velocities``;
    motion is considered started at the first window whose mean exceeds
    ``threshold``.

    Args:
        velocities: 1-D array of per-frame velocity magnitudes (must support
            slicing and ``.mean()``, e.g. a numpy array).
        window_size: Number of consecutive samples averaged per window.
        threshold: Mean velocity a window must exceed to count as motion.
        motion_buffer: Number of frames of context to keep before the detected start.

    Returns:
        ``max(0, start - motion_buffer)`` for the first moving window, or ``0``
        when no window exceeds the threshold (nothing is discarded).
    """
    # `+ 1` so the last full window (starting at len - window_size) is examined
    # too; without it, motion confined to the final window was never detected.
    for t in range(len(velocities) - window_size + 1):
        window_mean = velocities[t : t + window_size].mean()
        if window_mean > threshold:
            return max(0, t - motion_buffer)  # include slight context before motion
    return 0
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -120,13 +81,10 @@ class LeRobotDatasetMetadata:
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
feature_keys_mapping: dict[str, str] | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.local_files_only = local_files_only
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
@@ -141,14 +99,6 @@ class LeRobotDatasetMetadata:
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
# added by mshukor
|
||||
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||
self.inverse_feature_keys_mapping = (
|
||||
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||
)
|
||||
self.info["features"] = map_dict_keys(
|
||||
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
|
||||
)
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
@@ -227,15 +177,7 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
# changed
|
||||
keys = []
|
||||
for key, ft in self.features.items():
|
||||
key_ = (
|
||||
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
|
||||
)
|
||||
if ft["dtype"] == "video":
|
||||
keys.append(key_)
|
||||
return keys
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
@@ -400,18 +342,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
# new thing by M
|
||||
feature_keys_mapping: dict[str, str] | None = None,
|
||||
max_action_dim: int = None,
|
||||
max_state_dim: int = None,
|
||||
max_num_images: int = None,
|
||||
max_image_dim: int = None,
|
||||
training_features: list | None = None,
|
||||
discard_first_n_frames: int = 0,
|
||||
discard_first_idle_frames: bool = False,
|
||||
motion_threshold: float = 5e-2,
|
||||
motion_window_size: int = 10,
|
||||
motion_buffer: int = 3,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -525,34 +455,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.delta_indices = None
|
||||
|
||||
# by mshukor
|
||||
self.training_features = training_features
|
||||
self.discard_first_n_frames = discard_first_n_frames
|
||||
self.discard_first_idle_frames = discard_first_idle_frames
|
||||
self.motion_threshold = motion_threshold
|
||||
self.motion_window_size = motion_window_size
|
||||
self.motion_buffer = motion_buffer
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# more mshukor
|
||||
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||
self.inverse_feature_keys_mapping = (
|
||||
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||
)
|
||||
|
||||
# Load metadata
|
||||
# TODO: change
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id,
|
||||
self.root,
|
||||
self.revision,
|
||||
force_cache_sync=force_cache_sync,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
@@ -571,74 +482,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# mustafa code
|
||||
if self.discard_first_n_frames > 0:
|
||||
print("Discarding first n frames:", self.discard_first_n_frames)
|
||||
self.subset_frame_ids = []
|
||||
for ep_idx in range(self.num_episodes):
|
||||
from_ = self.episode_data_index["from"][ep_idx]
|
||||
to_ = self.episode_data_index["to"][ep_idx]
|
||||
# TODO implement advanced strategy
|
||||
self.subset_frame_ids += [
|
||||
frame_idx for frame_idx in range(from_ + int(self.fps * self.discard_first_n_frames), to_)
|
||||
]
|
||||
elif self.discard_first_idle_frames:
|
||||
print(
|
||||
f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}"
|
||||
)
|
||||
self.robot_states = torch.stack(self.hf_dataset[OBS_STATE]).numpy() # shape: [T, D]
|
||||
self.subset_frame_ids = []
|
||||
for ep_idx in range(self.num_episodes):
|
||||
from_ = self.episode_data_index["from"][ep_idx]
|
||||
to_ = self.episode_data_index["to"][ep_idx]
|
||||
ep_states = self.robot_states[from_:to_]
|
||||
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
|
||||
velocities = np.concatenate([[0.0], velocities])
|
||||
start_idx = find_start_of_motion(
|
||||
velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer
|
||||
)
|
||||
self.subset_frame_ids += list(range(from_ + start_idx, to_))
|
||||
|
||||
# Check timestamps
|
||||
# commented TODO: check why
|
||||
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
# TODO: check why commented
|
||||
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Mustafa
|
||||
self.meta.info["features"] = map_dict_keys(
|
||||
self.meta.info["features"],
|
||||
feature_keys_mapping=self.feature_keys_mapping,
|
||||
training_features=self.training_features,
|
||||
)
|
||||
self.keys_to_max_dim = {
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
}
|
||||
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||
)
|
||||
self.meta.stats = map_dict_keys(
|
||||
self.meta.stats,
|
||||
feature_keys_mapping=self.feature_keys_mapping,
|
||||
training_features=self.training_features,
|
||||
)
|
||||
self.robot_type = self.meta.info.get("robot_type", "")
|
||||
# Override tasks
|
||||
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
|
||||
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
@@ -793,7 +647,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
# FIXME(mshukor): what if we train on multiple datasets with different features
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
@@ -817,21 +670,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return query_timestamps
|
||||
|
||||
# TODO: changed by mustafa
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
queries = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if (
|
||||
key not in self.meta.video_keys
|
||||
and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys
|
||||
):
|
||||
key_ = (
|
||||
self.inverse_feature_keys_mapping.get(key, key)
|
||||
if self.inverse_feature_keys_mapping
|
||||
else key
|
||||
)
|
||||
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
|
||||
return queries
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
@@ -855,12 +699,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
# changed by mshukor
|
||||
def __getitem__(self, idx) -> dict:
|
||||
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
|
||||
idx = self.subset_frame_ids[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
query_indices = None
|
||||
@@ -877,27 +717,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
try:
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
except:
|
||||
print(self.meta.tasks, task_idx, self.repo_id)
|
||||
if "robot_type" not in item:
|
||||
item["robot_type"] = self.robot_type
|
||||
item = map_dict_keys(
|
||||
item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
|
||||
)
|
||||
# Add padded features
|
||||
# item = self._add_padded_features(item, self.training_features)
|
||||
if self.image_transforms is not None:
|
||||
for cam in item:
|
||||
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
# Map pad keys
|
||||
# print(item.keys(), "before")
|
||||
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||
# print(item.keys())
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1157,7 +985,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.local_files_only = obj.meta.local_files_only
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
@@ -1178,106 +1005,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return obj
|
||||
|
||||
|
||||
class MultiLeRobotDatasetMeta:
    """Aggregated metadata for a group of `LeRobotDataset`s trained together.

    Attributes set by the constructor:
        repo_ids: the repo ids passed in.
        keys_to_max_dim: mapping feature key -> maximum dimension (or None).
        train_on_all_features: whether features missing from some datasets are kept.
        robot_types: robot type of each dataset, after the per-repo override.
        disabled_features: feature keys dropped because they are not shared by all datasets.
        union_features: feature specs of all datasets, minus the disabled ones.
        features: `union_features` reshaped with `reshape_features_to_max_dim`.
        stats: per-robot-type statistics from `aggregate_stats_per_robot_type`.
        episodes / tasks / info: per-repo dictionaries of the datasets' metadata.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        keys_to_max_dim: dict[str, int],
        train_on_all_features: bool = False,
    ):
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features

        # Assign robot_type if missing / overridden for this repo. `.get` is used so a
        # dataset without a "robot_type" entry does not raise before it can be assigned.
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(
                ds.repo_id, ds.meta.info.get("robot_type", "")
            )
            ds.robot_type = ds.meta.info["robot_type"]

        # Collected AFTER the override so the values match the keys of `self.stats`.
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]

        # step 1: compute disabled features
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=False):
                extra = set(ds.features) - intersection
                logging.warning(f"Disabling {extra} for repo {repo_id}")
                self.disabled_features.update(extra)

        # step 2: build union_features excluding disabled
        self.union_features = {}
        for ds in datasets:
            for k, v in ds.features.items():
                if k not in self.disabled_features:
                    self.union_features[k] = v

        # step 3: reshape feature schema
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )

        # step 4: aggregate stats, then pad the vector-valued ones to the maximum dimension
        self.stats = aggregate_stats_per_robot_type(datasets)
        for stats_ in self.stats.values():
            for feat_key, feat_stats in stats_.items():
                if feat_key not in (ACTION, OBS_ENV_STATE, OBS_STATE):
                    continue
                max_size = self.keys_to_max_dim.get(feat_key)
                if max_size is None:
                    # No maximum configured for this key: nothing to pad to.
                    continue
                for k, v in feat_stats.items():
                    # "min" and "mean" are padded with 0, every other statistic with 1.
                    pad_value = 0 if k in ["min", "mean"] else 1
                    feat_stats[k] = pad_tensor(
                        v,
                        max_size=max_size,
                        pad_dim=-1,
                        pad_value=pad_value,
                    )

        # step 5: episodes & tasks
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||
|
||||
|
||||
class MultiLeRobotDatasetCleaner:
    """Filters a list of datasets and keeps the per-dataset bookkeeping aligned.

    Datasets are first filtered with `keep_datasets_with_valid_fps`, then with
    `keep_datasets_with_the_same_features_per_robot_type`. The sampling weights and
    repo ids of the surviving datasets are selected using their position in the
    ORIGINAL `datasets` list, so they stay aligned even when the fps filter drops
    a dataset.

    Attributes:
        original_*: the untouched constructor arguments.
        cleaned_datasets: datasets that passed both filters.
        keep_mask: mask returned by the feature-consistency filter (one entry per
            fps-valid dataset).
        cleaned_weights: float32 numpy array of the surviving sampling weights.
        cleaned_repo_ids / cleaned_datasets_repo_ids: surviving repo ids.
        cumulative_sizes: integer numpy array `[0, len(d0), len(d0) + len(d1), ...]`.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        sampling_weights: list[float],
        datasets_repo_ids: list[str],
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids

        # step 1: remove datasets with invalid fps. A copy is passed so that
        # `original_datasets` is left intact whatever the helper does to its argument.
        valid_fps_datasets = keep_datasets_with_valid_fps(list(datasets), min_fps=min_fps, max_fps=max_fps)

        # Recover the original position of every fps-valid dataset. The filter preserves
        # order, so walking both lists in lockstep and matching by identity is enough.
        valid_indices = []
        cursor = 0
        for original_idx, ds in enumerate(datasets):
            if cursor < len(valid_fps_datasets) and valid_fps_datasets[cursor] is ds:
                valid_indices.append(original_idx)
                cursor += 1

        # step 2: keep datasets with same features per robot type
        consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
            valid_fps_datasets
        )

        self.cleaned_datasets = consistent_datasets
        self.keep_mask = keep_mask

        kept_indices = [idx for idx, keep in zip(valid_indices, keep_mask, strict=False) if keep]
        self.cleaned_weights = [sampling_weights[i] for i in kept_indices]
        self.cleaned_repo_ids = [repo_ids[i] for i in kept_indices]
        self.cleaned_datasets_repo_ids = [datasets_repo_ids[i] for i in kept_indices]

        # Offsets of each dataset in the concatenated index space, starting at 0.
        sizes = [len(d) for d in consistent_datasets]
        self.cumulative_sizes = np.concatenate(([0], np.cumsum(sizes, dtype=np.int64)))
        self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
@@ -1294,24 +1021,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
# add
|
||||
sampling_weights: list[float] | None = None,
|
||||
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
|
||||
max_action_dim: int = None,
|
||||
max_state_dim: int = None,
|
||||
max_num_images: int = None,
|
||||
max_image_dim: int = None,
|
||||
train_on_all_features: bool = False,
|
||||
training_features: list | None = None,
|
||||
discard_first_n_frames: int = 0,
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 100,
|
||||
discard_first_idle_frames: bool = False,
|
||||
motion_threshold: float = 0.05,
|
||||
motion_window_size: int = 10,
|
||||
motion_buffer: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
@@ -1319,89 +1029,46 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
_datasets = []
|
||||
datasets_repo_ids = []
|
||||
self.sampling_weights = []
|
||||
self.training_features = training_features
|
||||
|
||||
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
|
||||
assert len(sampling_weights) == len(repo_ids), (
|
||||
"The number of sampling weights must match the number of datasets. "
|
||||
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
|
||||
)
|
||||
for i, repo_id in enumerate(repo_ids):
|
||||
try:
|
||||
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
_datasets.append(
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes.get(repo_id, None) if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
training_features=training_features,
|
||||
discard_first_n_frames=discard_first_n_frames,
|
||||
discard_first_idle_frames=discard_first_idle_frames,
|
||||
motion_threshold=motion_threshold,
|
||||
motion_window_size=motion_window_size,
|
||||
motion_buffer=motion_buffer,
|
||||
)
|
||||
)
|
||||
datasets_repo_ids.append(repo_id)
|
||||
self.sampling_weights.append(float(sampling_weights[i]))
|
||||
except Exception as e:
|
||||
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
|
||||
print(
|
||||
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
|
||||
)
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
# FIXME(mshukor): apply mapping to unify used keys
|
||||
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = (
|
||||
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||
) # delta_timestamps # FIXME(mshukor): last repo?
|
||||
# In case datasets with the same robot_type have different features
|
||||
cleaner = MultiLeRobotDatasetCleaner(
|
||||
datasets=_datasets,
|
||||
repo_ids=repo_ids,
|
||||
sampling_weights=self.sampling_weights,
|
||||
datasets_repo_ids=datasets_repo_ids,
|
||||
min_fps=min_fps,
|
||||
max_fps=max_fps,
|
||||
)
|
||||
self._datasets = cleaner.cleaned_datasets
|
||||
self.sampling_weights = cleaner.cleaned_weights
|
||||
self.repo_ids = cleaner.cleaned_repo_ids
|
||||
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
|
||||
self.cumulative_sizes = cleaner.cumulative_sizes
|
||||
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||
# self.meta.info = {
|
||||
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
# }
|
||||
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
||||
self.meta = MultiLeRobotDatasetMeta(
|
||||
datasets=self._datasets,
|
||||
repo_ids=self.repo_ids,
|
||||
keys_to_max_dim={
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
},
|
||||
train_on_all_features=train_on_all_features,
|
||||
)
|
||||
self.disabled_features = self.meta.disabled_features
|
||||
self.stats = self.meta.stats
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
@@ -1489,14 +1156,23 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
|
||||
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||
item = self._datasets[dataset_idx][local_idx]
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
item = create_padded_features(item, self.meta.features)
|
||||
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
||||
for data_key in self.disabled_features:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -858,21 +858,3 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
|
||||
def map_dict_keys(
    item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
    """Rename the keys of `item` according to `feature_keys_mapping`.

    Args:
        item: dictionary whose keys should be renamed.
        feature_keys_mapping: maps an original key to its new name, or to None to
            drop that key. When the mapping itself is None, `item` is returned as-is.
        training_features: optional allow-list. A mapped key is kept only if its NEW
            name is listed; an unmapped key is kept if it is listed or contains `pad_key`.
        pad_key: substring identifying padding entries, which bypass the allow-list.

    Returns:
        A new dictionary with the selected, renamed entries (or `item` itself when
        no mapping is given).
    """
    if feature_keys_mapping is None:
        return item

    def _is_selected(name):
        return training_features is None or name in training_features

    renamed = {}
    for source_key, value in item.items():
        if source_key not in feature_keys_mapping:
            # Unmapped keys keep their name.
            if _is_selected(source_key) or pad_key in source_key:
                renamed[source_key] = value
            continue
        target_key = feature_keys_mapping[source_key]
        if target_key is not None and _is_selected(target_key):
            renamed[target_key] = value
    return renamed
|
||||
|
||||
@@ -1,409 +0,0 @@
|
||||
"""
|
||||
Utils function by Mustafa to refactor
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGE_2 = "observation.image2"
|
||||
OBS_IMAGE_3 = "observation.image3"
|
||||
|
||||
|
||||
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
|
||||
"""Reshape features to have a maximum dimension of `max_dim`."""
|
||||
reshaped_features = {}
|
||||
for key in features:
|
||||
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
|
||||
reshaped_features[key] = features[key]
|
||||
shape = list(features[key]["shape"])
|
||||
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
|
||||
shape[-3] = keys_to_max_dim[key]
|
||||
shape[-2] = keys_to_max_dim[key]
|
||||
else:
|
||||
shape[reshape_dim] = keys_to_max_dim[key]
|
||||
reshaped_features[key]["shape"] = tuple(shape)
|
||||
else:
|
||||
reshaped_features[key] = features[key]
|
||||
return reshaped_features
|
||||
|
||||
|
||||
def keep_datasets_with_valid_fps(ls_datasets: list, min_fps: int = 1, max_fps: int = 100) -> list:
    """Return the datasets whose `fps` lies within `[min_fps, max_fps]`.

    The result is a new list in the original order; `ls_datasets` itself is not
    modified. (Removing items from the list while iterating over it would skip the
    element following each removal.)

    Args:
        ls_datasets: objects exposing an `fps` attribute.
        min_fps: smallest accepted fps (inclusive).
        max_fps: largest accepted fps (inclusive).
    """
    print(
        f"Keeping datasets with fps between {min_fps} and {max_fps}. Considering {len(ls_datasets)} datasets."
    )
    valid_datasets = []
    for ds in ls_datasets:
        if ds.fps < min_fps or ds.fps > max_fps:
            print(f"Dataset {ds} has invalid fps: {ds.fps}. Removing it.")
            continue
        valid_datasets.append(ds)
    print(f"Keeping {len(valid_datasets)} datasets with valid fps.")
    return valid_datasets
|
||||
|
||||
|
||||
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
|
||||
"""
|
||||
Filters datasets to only keep those with consistent feature shapes per robot type.
|
||||
|
||||
Args:
|
||||
ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
|
||||
and `meta.episodes_stats` dictionary.
|
||||
|
||||
Returns:
|
||||
List: Filtered list of datasets with consistent feature shapes.
|
||||
"""
|
||||
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
|
||||
datasets_to_remove = set()
|
||||
|
||||
for robot_type in robot_types:
|
||||
# Collect all stats dicts for this robot type
|
||||
stats_list = [
|
||||
ep_stats
|
||||
for ds in ls_datasets
|
||||
if ds.meta.info["robot_type"] == robot_type
|
||||
for ep_stats in ds.meta.episodes_stats.values()
|
||||
]
|
||||
if not stats_list:
|
||||
continue
|
||||
|
||||
# Determine the most common shape for each key
|
||||
all_keys = {key for stats in stats_list for key in stats}
|
||||
for ds in ls_datasets:
|
||||
if ds.meta.info["robot_type"] != robot_type:
|
||||
continue
|
||||
for key in all_keys:
|
||||
shape_counter = defaultdict(int)
|
||||
|
||||
for stats in stats_list:
|
||||
value = stats.get(key)
|
||||
if (
|
||||
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
|
||||
): # FIXME(mshukor): check all stats; min, mean, max
|
||||
shape_counter[value["mean"].shape] += 1
|
||||
if not shape_counter:
|
||||
continue
|
||||
|
||||
# Identify the most frequent shape
|
||||
main_shape = max(shape_counter, key=shape_counter.get)
|
||||
# Flag datasets that don't match the main shape
|
||||
# for ds in ls_datasets:
|
||||
first_ep_stats = next(iter(ds.meta.episodes_stats.values()), None)
|
||||
if not first_ep_stats:
|
||||
continue
|
||||
value = first_ep_stats.get(key)
|
||||
if (
|
||||
value
|
||||
and "mean" in value
|
||||
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
|
||||
and value["mean"].shape != main_shape
|
||||
):
|
||||
datasets_to_remove.add(ds)
|
||||
break
|
||||
|
||||
# Filter out inconsistent datasets
|
||||
datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
|
||||
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
|
||||
print(
|
||||
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
|
||||
)
|
||||
return filtered_datasets, datasets_maks
|
||||
|
||||
|
||||
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
    """Aggregate the episode stats of several LeRobot datasets, one result per robot type.

    The episode stats of all datasets sharing the same `meta.info["robot_type"]` are
    gathered (in dataset order) and handed to `aggregate_stats`. The stats of a robot
    type therefore cover the union of the data keys of its datasets. For instance:
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_mean = (mean of all data)
    - new_std = (std of all data)

    Returns:
        Mapping robot type -> aggregated stats.
    """
    episode_stats_by_robot = {}
    for dataset in ls_datasets:
        robot_type = dataset.meta.info["robot_type"]
        episode_stats_by_robot.setdefault(robot_type, []).extend(dataset.meta.episodes_stats.values())

    return {
        robot_type: aggregate_stats(episode_stats)
        for robot_type, episode_stats in episode_stats_by_robot.items()
    }
|
||||
|
||||
|
||||
def str_to_torch_dtype(dtype_str):
    """Translate a feature dtype name into the corresponding torch dtype.

    "int64", "int16" and "bool" map to their torch equivalents. Everything else,
    including "float32", "video" and unrecognised names, resolves to
    `torch.float32`.
    """
    non_float_dtypes = {
        "int64": torch.int64,
        "int16": torch.int16,
        "bool": torch.bool,
    }
    return non_float_dtypes.get(dtype_str, torch.float32)
|
||||
|
||||
|
||||
def create_padded_features(item: dict, features: dict = {}):
    """Add zero-filled placeholders to `item` for every expected feature it lacks.

    For each feature in `features` (except those whose key contains "cam", "effort"
    or "absolute"), `item` is updated in place:
    - feature missing: a zero tensor of the feature's shape/dtype is inserted and
      `<key>_padding_mask` is set to 0; image keys additionally get
      `<key>_is_pad = [False]`;
    - feature present: only `<key>_padding_mask` is set to 1.

    Three-element shapes are reordered from (a, b, c) to (c, a, b), i.e. treated as
    images converted to torch's channel-first layout, and a shape of `(1,)` becomes
    a 0-d tensor.

    Returns:
        The same `item` dictionary.
    """
    skipped_markers = ("cam", "effort", "absolute")  # FIXME(mshukor): temporary hack

    for name, spec in features.items():
        if any(marker in name for marker in skipped_markers):
            continue

        dims = spec["shape"]
        if len(dims) == 3:
            # images to torch format (C, H, W)
            height, width, channels = dims
            dims = (channels, height, width)
        elif len(dims) == 1 and dims[0] == 1:
            # single-element features are stored as scalar tensors
            dims = []

        mask_key = f"{name}_padding_mask"
        if name in item:
            item[mask_key] = torch.tensor(1, dtype=torch.int64)
            continue

        item[name] = torch.zeros(dims, dtype=str_to_torch_dtype(spec["dtype"]))
        item[mask_key] = torch.tensor(0, dtype=torch.int64)
        if "image" in name:  # FIXME(mshukor): support other observations
            item[f"{name}_is_pad"] = torch.BoolTensor([False])

    return item
|
||||
|
||||
|
||||
# Robot type assigned to specific dataset repos, keyed by repo id.
ROBOT_TYPE_KEYS_MAPPING = {
    **dict.fromkeys(
        (
            "lerobot/stanford_hydra_dataset",
            "lerobot/iamlab_cmu_pickup_insert",
            "lerobot/berkeley_fanuc_manipulation",
            "lerobot/toto",
            "lerobot/roboturk",
            "lerobot/jaco_play",
        ),
        "static_single_arm",
    ),
    "lerobot/taco_play": "static_single_arm_7statedim",
}
|
||||
|
||||
|
||||
def pad_tensor(
    tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
    """Right-pad `tensor` along `pad_dim` until that axis has length `max_size`.

    Args:
        tensor: a torch tensor or a numpy array.
        max_size: target length of the padded axis. When it is None, or not larger
            than the current length, no padding is added.
        pad_dim: axis to pad (negative values count from the end).
        pad_value: fill value for the added entries.

    Returns:
        The padded data, of the same kind as the input (numpy in, numpy out).
        0-d inputs and inputs with `max_size=None` are returned unchanged.
    """
    if max_size is None or tensor.ndim == 0:
        # Nothing to pad to, or a scalar with no axis to pad.
        return tensor

    is_numpy = isinstance(tensor, np.ndarray)
    if is_numpy:
        tensor = torch.tensor(tensor)

    pad = max_size - tensor.shape[pad_dim]
    if pad > 0:
        # `F.pad` takes (left, right) pairs starting from the LAST axis, so the axes
        # after `pad_dim` need explicit zero pairs before the pair for `pad_dim`.
        axis = pad_dim % tensor.ndim
        pad_sizes = [0, 0] * (tensor.ndim - 1 - axis) + [0, pad]  # pad right
        tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
    return tensor.numpy() if is_numpy else tensor
|
||||
|
||||
|
||||
def map_dict_keys(
    item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
    """Maps feature keys from the dataset to the keys used in the model.

    Args:
        item: dictionary whose keys should be renamed.
        feature_keys_mapping: maps a dataset key to the model key, or to None to drop
            the entry. When the mapping itself is None, `item` is returned unchanged.
        training_features: optional allow-list. Mapped entries are kept only when the
            model key is listed; unmapped entries are kept when their key is listed
            or contains `pad_key`.
        pad_key: substring marking padding entries, which bypass the allow-list.

    Returns:
        A new dictionary holding the retained entries under their model keys.
    """
    if feature_keys_mapping is None:
        return item

    allow_all = training_features is None
    mapped = {}
    for dataset_key, value in item.items():
        if dataset_key in feature_keys_mapping:
            model_key = feature_keys_mapping[dataset_key]
            if model_key is None:
                continue  # explicitly dropped by the mapping
            if allow_all or model_key in training_features:
                mapped[model_key] = value
        elif allow_all or dataset_key in training_features or pad_key in dataset_key:
            mapped[dataset_key] = value
    return mapped
|
||||
|
||||
|
||||
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
    """Return the index at which motion starts in a sequence of velocities.

    A window of `window_size` consecutive velocities is slid over the sequence; the
    first window whose mean exceeds `threshold` marks the start of motion.

    Args:
        velocities: array of per-frame velocity magnitudes (slices must support `.mean()`).
        window_size: number of consecutive frames averaged together.
        threshold: mean velocity above which a window counts as motion.
        motion_buffer: number of frames of context to keep before the detected window.

    Returns:
        `max(0, t - motion_buffer)` for the first moving window starting at `t`,
        or 0 when no window exceeds the threshold (including when the sequence is
        shorter than `window_size`).
    """
    # `+ 1` so that the last full window, starting at len - window_size, is examined too.
    for t in range(len(velocities) - window_size + 1):
        window_mean = velocities[t : t + window_size].mean()
        if window_mean > threshold:
            return max(0, t - motion_buffer)  # include slight context before motion
    return 0
|
||||
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
|
||||
def load_yaml_mapping(name: str, timeout: float = 30.0) -> dict:
    """
    Loads a YAML mapping from a Hugging Face repo.
    Example: name='features' → https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/features.yaml

    Args:
        name: base name of the YAML file (without the `.yaml` extension).
        timeout: seconds to wait for the server before giving up. Without a timeout
            `requests` may block indefinitely on an unresponsive connection.

    Returns:
        The parsed YAML content, or an empty dict when the document is empty.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not answer within `timeout` seconds.
    """
    url = f"https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/{name}.yaml"
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # raise if the download fails

    # `safe_load` returns None for an empty document; callers expect a mapping.
    return yaml.safe_load(response.text) or {}
|
||||
|
||||
|
||||
# Module-level mappings. The two YAML-backed ones are downloaded when this module is imported.
|
||||
TASKS_KEYS_MAPPING = load_yaml_mapping("tasks")
|
||||
FEATURE_KEYS_MAPPING = load_yaml_mapping("features")
|
||||
# Episode indices to use for specific dataset repos, keyed by repo id.
EPISODES_DATASET_MAPPING = {
    "cadene/droid_1.0.1": list(range(50)),
    # episodes 0-51 without 8 and 23
    "danaaubakirova/svla_so100_task5_v3": [*range(8), *range(9, 23), *range(24, 52)],
    # episodes 0-53 without 20 and 36-39
    "danaaubakirova/svla_so100_task4_v3": [*range(20), *range(21, 36), *range(40, 54)],
}
|
||||
# Keys of the non-image entries of a training sample.
ACTION = "action"
OBS_STATE = "observation.state"
TASK = "task"
ROBOT = "robot_type"

# Feature sets used for training, indexed 0..2: entry i contains the four keys above
# followed by the first i + 1 image keys.
TRAINING_FEATURES = {
    index: [ACTION, OBS_STATE, TASK, ROBOT, *(OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3)[: index + 1]]
    for index in range(3)
}
|
||||
|
||||
|
||||
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> bool:
    """Tell whether a batch of tensors is a candidate for padding.

    Only the first tensor is inspected: the batch qualifies as soon as that
    tensor has at least one dimension (i.e. it is not a 0-d scalar tensor).

    Args:
        values: Tensors collected for one key across the batch.
        pad_dim: Currently unused. It belonged to a stricter check (sizes along
            ``pad_dim`` differ between samples) that is disabled.

    Returns:
        ``True`` if the first tensor has one or more dimensions, ``False`` for
        a 0-d tensor or an empty ``values`` list.
    """
    if not values:
        # Nothing to pad; avoids an IndexError on values[0].
        return False
    # Disabled stricter variant: and len(set([v.shape[pad_dim] for v in values])) > 1
    return values[0].ndim > 0
|
||||
|
||||
|
||||
def pad_tensor_to_shape(tensor: torch.Tensor, target_shape: tuple, pad_value: float = 0.0) -> torch.Tensor:
|
||||
"""Pads a tensor to the target shape (right/bottom only)."""
|
||||
pad = []
|
||||
for actual, target in zip(reversed(tensor.shape), reversed(target_shape), strict=False):
|
||||
pad.extend([0, max(target - actual, 0)])
|
||||
return F.pad(tensor, pad, value=pad_value)
|
||||
|
||||
|
||||
def multidataset_collate_fn(
    batch: List[Dict[str, torch.Tensor]],
    keys_to_max_dim: Dict[str, tuple] | None = None,
    pad_value: float = 0.0,
) -> Dict[str, torch.Tensor]:
    """Pad every tensor in ``batch`` to a common shape per key, then collate.

    Pads tensors to the given target shape (if provided), otherwise to the
    per-batch maximum along each dimension.
    Supports 1D (e.g. action), 3D (e.g. [C,H,W] images).

    Args:
        batch: One dict per sample. The keys of the first sample decide which
            keys are collated.
        keys_to_max_dim: Optional mapping from key to the shape its tensors
            should be padded to. A missing key or a ``None`` value falls back
            to the per-batch maximum shape.
        pad_value: Fill value used for the added elements.

    Returns:
        The result of ``default_collate`` applied to the padded samples.
        Non-tensor values are passed to it unchanged.
    """
    # A None default avoids sharing one mutable dict between calls.
    if keys_to_max_dim is None:
        keys_to_max_dim = {}

    collated_batch = [{} for _ in range(len(batch))]
    batch_keys = batch[0].keys()

    for key in batch_keys:
        values = [sample[key] for sample in batch]
        sample = values[0]

        if not isinstance(sample, torch.Tensor):
            # Non-tensor entries (decided from the first sample) are copied as-is.
            for i in range(len(batch)):
                collated_batch[i][key] = values[i]
            continue

        # use user-specified shape if available
        target_shape = keys_to_max_dim.get(key)
        if target_shape is None:
            # compute per-batch max shape
            target_shape = tuple(max(v.shape[i] for v in values) for i in range(sample.ndim))

        for i in range(len(batch)):
            collated_batch[i][key] = pad_tensor_to_shape(values[i], target_shape, pad_value=pad_value)

    return default_collate(collated_batch)
|
||||
401
src/lerobot/motors/calibration_gui.py
Normal file
401
src/lerobot/motors/calibration_gui.py
Normal file
@@ -0,0 +1,401 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
|
||||
|
||||
from lerobot.motors import MotorCalibration, MotorsBus
|
||||
|
||||
# Slider bar and the shapes drawn on it.
BAR_LEN = 450
BAR_THICKNESS = 8
HANDLE_R = 10  # click radius around a min/max bracket
BRACKET_W = 6
BRACKET_H = 14
TRI_W = 12
TRI_H = 14

# Buttons and dropdown.
BTN_W = 60  # "set min" / "set max"
BTN_H = 22
SAVE_W = 80
SAVE_H = 28
LOAD_W = 80
DD_W = 160
DD_H = 28

# Vertical layout, font and frame rate.
TOP_GAP = 50  # space between the top controls and the first slider row
PADDING_Y = 70  # distance between consecutive slider rows
TOP_OFFSET = 60
FONT_SIZE = 20
FPS = 60

# Colours (RGB).
BG_COLOR = (30, 30, 30)
BAR_RED = (200, 60, 60)
BAR_GREEN = (60, 200, 60)
HANDLE_COLOR = (240, 240, 240)
TEXT_COLOR = (250, 250, 250)
TICK_COLOR = (250, 220, 40)
BTN_COLOR = (80, 80, 80)
BTN_COLOR_HL = (110, 110, 110)
DD_COLOR = (70, 70, 70)
DD_COLOR_HL = (100, 100, 100)
|
||||
|
||||
|
||||
def dist(a, b):
    """Return the Euclidean distance between the first two coordinates of ``a`` and ``b``."""
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return math.hypot(dx, dy)
|
||||
|
||||
|
||||
@dataclass
class RangeValues:
    """The three values of one slider row, as returned by ``RangeSlider.values``."""

    # Lower bound of the range.
    min_v: int
    # Position currently selected on the slider.
    pos_v: int
    # Upper bound of the range.
    max_v: int
|
||||
|
||||
|
||||
class RangeSlider:
    """One slider row of the GUI, bound to a single motor.

    The row maps values ``0..res`` onto a horizontal bar of ``BAR_LEN`` pixels
    and tracks three values with their x-coordinates: the range bounds
    (``min_v`` / ``max_v``, drawn as brackets) and the selected position
    (``pos_v``, drawn as a triangle). A separate tick marks ``tick_val``.
    """

    def __init__(self, motor, idx, res, calibration, present, label_pad, base_y):
        import pygame

        self.motor = motor
        self.res = res

        # Geometry of the bar: left edge, right edge, vertical centre of the row.
        self.x0 = 40 + label_pad
        self.x1 = self.x0 + BAR_LEN
        self.y = base_y + idx * PADDING_Y

        # Values: bounds come from the calibration, the position starts at the
        # present value clamped into those bounds.
        self.min_v = calibration.range_min
        self.max_v = calibration.range_max
        self.pos_v = max(self.min_v, min(present, self.max_v))

        # x-coordinates matching the three values above.
        self.min_x, self.max_x, self.pos_x = (
            self._pos_from_val(value) for value in (self.min_v, self.max_v, self.pos_v)
        )

        # "set min" sits left of the bar, "set max" right of it.
        button_top = self.y - BTN_H // 2
        self.min_btn = pygame.Rect(self.x0 - BTN_W - 6, button_top, BTN_W, BTN_H)
        self.max_btn = pygame.Rect(self.x1 + 6, button_top, BTN_W, BTN_H)

        self.drag_min = False
        self.drag_max = False
        self.drag_pos = False
        self.tick_val = present
        self.font = pygame.font.Font(None, FONT_SIZE)

    def _val_from_pos(self, x):
        """Convert an x-coordinate on the bar into a rounded value."""
        fraction = (x - self.x0) / BAR_LEN
        return round(fraction * self.res)

    def _pos_from_val(self, v):
        """Convert a value into its x-coordinate on the bar."""
        fraction = v / self.res
        return self.x0 + fraction * BAR_LEN

    def set_tick(self, v):
        """Store ``v`` as the tick value, limited to ``0..res``."""
        capped = min(v, self.res)
        self.tick_val = max(0, capped)

    def _triangle_hit(self, pos):
        """Return whether ``pos`` lies inside the bounding box of the position triangle."""
        import pygame

        tip_y = self.y - BAR_THICKNESS // 2 - 2
        bounding_box = pygame.Rect(self.pos_x - TRI_W // 2, tip_y - TRI_H, TRI_W, TRI_H)
        return bounding_box.collidepoint(pos)

    def handle_event(self, e):
        """Update buttons and drag state from one pygame event."""
        import pygame

        if e.type == pygame.MOUSEBUTTONDOWN:
            if e.button != 1:
                return
            # The buttons copy the selected position into a bound.
            if self.min_btn.collidepoint(e.pos):
                self.min_x, self.min_v = self.pos_x, self.pos_v
                return
            if self.max_btn.collidepoint(e.pos):
                self.max_x, self.max_v = self.pos_x, self.pos_v
                return
            # Otherwise start at most one drag; the min bracket wins over the
            # max bracket, which wins over the triangle.
            if dist(e.pos, (self.min_x, self.y)) <= HANDLE_R:
                self.drag_min = True
            elif dist(e.pos, (self.max_x, self.y)) <= HANDLE_R:
                self.drag_max = True
            elif self._triangle_hit(e.pos):
                self.drag_pos = True
            return

        if e.type == pygame.MOUSEBUTTONUP:
            if e.button == 1:
                self.drag_min = self.drag_max = self.drag_pos = False
            return

        if e.type != pygame.MOUSEMOTION:
            return

        mouse_x = e.pos[0]
        if self.drag_min:
            # Stay on the bar and never move past the selected position.
            self.min_x = max(self.x0, min(mouse_x, self.pos_x))
        elif self.drag_max:
            self.max_x = min(self.x1, max(mouse_x, self.pos_x))
        elif self.drag_pos:
            # The selected position stays between the two brackets.
            self.pos_x = max(self.min_x, min(mouse_x, self.max_x))

        # Keep the values in sync with the coordinates.
        self.min_v = self._val_from_pos(self.min_x)
        self.max_v = self._val_from_pos(self.max_x)
        self.pos_v = self._val_from_pos(self.pos_x)

    def _draw_button(self, surf, rect, text):
        """Draw one button, highlighted while the mouse is over it, with centred text."""
        import pygame

        hovered = rect.collidepoint(pygame.mouse.get_pos())
        fill = BTN_COLOR_HL if hovered else BTN_COLOR
        pygame.draw.rect(surf, fill, rect, border_radius=4)

        caption = self.font.render(text, True, TEXT_COLOR)
        caption_x = rect.centerx - caption.get_width() // 2
        caption_y = rect.centery - caption.get_height() // 2
        surf.blit(caption, (caption_x, caption_y))

    def draw(self, surf):
        """Draw the whole row onto ``surf``."""
        import pygame

        bar_top = self.y - BAR_THICKNESS // 2
        bar_bottom = self.y + BAR_THICKNESS // 2
        bracket_top = self.y - BRACKET_H // 2
        bracket_bottom = self.y + BRACKET_H // 2

        # motor name above set-min button (right-aligned)
        name_surf = self.font.render(self.motor, True, TEXT_COLOR)
        name_x = self.min_btn.right - name_surf.get_width()
        name_y = self.min_btn.y - name_surf.get_height() - 4
        surf.blit(name_surf, (name_x, name_y))

        # bar + active section
        pygame.draw.rect(surf, BAR_RED, (self.x0, bar_top, BAR_LEN, BAR_THICKNESS))
        pygame.draw.rect(surf, BAR_GREEN, (self.min_x, bar_top, self.max_x - self.min_x, BAR_THICKNESS))

        # tick
        tick_x = self._pos_from_val(self.tick_val)
        pygame.draw.line(surf, TICK_COLOR, (tick_x, bar_top - 4), (tick_x, bar_bottom + 4), 2)

        # brackets: a vertical stroke plus two short lips pointing into the range
        for edge_x, direction in ((self.min_x, +1), (self.max_x, -1)):
            lip_x = edge_x + direction * BRACKET_W
            pygame.draw.line(surf, HANDLE_COLOR, (edge_x, bracket_top), (edge_x, bracket_bottom), 2)
            pygame.draw.line(surf, HANDLE_COLOR, (edge_x, bracket_top), (lip_x, bracket_top), 2)
            pygame.draw.line(surf, HANDLE_COLOR, (edge_x, bracket_bottom), (lip_x, bracket_bottom), 2)

        # triangle ▼
        tri_top = bar_top - 2
        tri_base_y = tri_top - TRI_H
        half_tri = TRI_W // 2
        pygame.draw.polygon(
            surf,
            HANDLE_COLOR,
            [
                (self.pos_x, tri_top),
                (self.pos_x - half_tri, tri_base_y),
                (self.pos_x + half_tri, tri_base_y),
            ],
        )

        # numeric labels, horizontally centred on their marker
        font_height = self.font.get_height()
        bound_label_y = bracket_top - 4 - font_height
        pos_label_y = tri_base_y - 4 - font_height
        labels = (
            (self.min_v, self.min_x, bound_label_y),
            (self.max_v, self.max_x, bound_label_y),
            (self.pos_v, self.pos_x, pos_label_y),
        )
        for value, centre_x, top_y in labels:
            rendered = self.font.render(str(value), True, TEXT_COLOR)
            surf.blit(rendered, (centre_x - rendered.get_width() // 2, top_y))

        # buttons
        self._draw_button(surf, self.min_btn, "set min")
        self._draw_button(surf, self.max_btn, "set max")

    # external
    def values(self) -> RangeValues:
        """Return the current min / position / max values of the row."""
        return RangeValues(min_v=self.min_v, pos_v=self.pos_v, max_v=self.max_v)
|
||||
|
||||
|
||||
class RangeFinderGUI:
    """Pygame window for adjusting the range limits of the motors on a bus.

    Motors are shown one group at a time (selectable through a dropdown), one
    ``RangeSlider`` per motor. SAVE writes the slider bounds to the bus, LOAD
    re-reads the calibration from the bus.
    """

    def __init__(self, bus: MotorsBus, groups: dict[str, list[str]] | None = None):
        """Connect to ``bus`` if needed, read its state and open the window.

        Args:
            bus: Motors bus to read from and write to.
            groups: Mapping from group name to motor names. When ``None``, a
                single group ``"all"`` holding every motor of the bus is used.
        """
        import pygame

        self.bus = bus
        self.groups = groups if groups is not None else {"all": list(bus.motors)}
        # From here on always go through self.groups: the raw `groups` argument
        # is None when the default is used.
        self.group_names = list(self.groups)
        self.current_group = self.group_names[0]

        if not bus.is_connected:
            bus.connect()

        self.calibration = bus.read_calibration()
        self.res_table = bus.model_resolution_table
        self.present_cache = {
            m: bus.read("Present_Position", m, normalize=False)
            for motors in self.groups.values()
            for m in motors
        }

        pygame.init()
        self.font = pygame.font.Font(None, FONT_SIZE)

        # Widest motor name decides how far right the bars start.
        label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
        self.label_pad = label_pad
        width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
        self.controls_bottom = 10 + SAVE_H
        self.base_y = self.controls_bottom + TOP_GAP
        height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40

        self.screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Motors range finder")

        # ui rects
        self.save_btn = pygame.Rect(width - SAVE_W - 10, 10, SAVE_W, SAVE_H)
        self.load_btn = pygame.Rect(self.save_btn.left - LOAD_W - 10, 10, LOAD_W, SAVE_H)
        self.dd_btn = pygame.Rect(width // 2 - DD_W // 2, 10, DD_W, DD_H)
        self.dd_open = False  # dropdown expanded?

        self.clock = pygame.time.Clock()
        self._build_sliders()
        self._adjust_height()

    def _adjust_height(self):
        """Resize the window so that it fits the sliders of the current group."""
        import pygame

        motors = self.groups[self.current_group]
        new_h = self.base_y + PADDING_Y * len(motors) + 40
        if new_h != self.screen.get_height():
            w = self.screen.get_width()
            self.screen = pygame.display.set_mode((w, new_h))

    def _build_sliders(self):
        """Create one slider per motor of the current group."""
        self.sliders: list[RangeSlider] = []
        motors = self.groups[self.current_group]
        for i, m in enumerate(motors):
            self.sliders.append(
                RangeSlider(
                    motor=m,
                    idx=i,
                    # highest value of the model = resolution - 1
                    res=self.res_table[self.bus.motors[m].model] - 1,
                    calibration=self.calibration[m],
                    present=self.present_cache[m],
                    label_pad=self.label_pad,
                    base_y=self.base_y,
                )
            )

    def _draw_dropdown(self):
        """Draw the group dropdown, plus its item list while it is open."""
        import pygame

        # collapsed box
        hover = self.dd_btn.collidepoint(pygame.mouse.get_pos())
        pygame.draw.rect(self.screen, DD_COLOR_HL if hover else DD_COLOR, self.dd_btn, border_radius=6)

        txt = self.font.render(self.current_group, True, TEXT_COLOR)
        self.screen.blit(
            txt, (self.dd_btn.centerx - txt.get_width() // 2, self.dd_btn.centery - txt.get_height() // 2)
        )

        # small downward arrow near the right edge of the box
        tri_w, tri_h = 12, 6
        cx = self.dd_btn.right - 14
        cy = self.dd_btn.centery + 1
        pygame.draw.polygon(
            self.screen,
            TEXT_COLOR,
            [(cx - tri_w // 2, cy - tri_h // 2), (cx + tri_w // 2, cy - tri_h // 2), (cx, cy + tri_h // 2)],
        )

        if not self.dd_open:
            return

        # expanded list
        for i, name in enumerate(self.group_names):
            item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H)
            clr = DD_COLOR_HL if item_rect.collidepoint(pygame.mouse.get_pos()) else DD_COLOR
            pygame.draw.rect(self.screen, clr, item_rect)
            t = self.font.render(name, True, TEXT_COLOR)
            self.screen.blit(
                t, (item_rect.centerx - t.get_width() // 2, item_rect.centery - t.get_height() // 2)
            )

    def _handle_dropdown_event(self, e):
        """Process a left click on the dropdown.

        Returns:
            ``True`` when the event was consumed by the dropdown (box toggled
            or an item picked), ``False`` otherwise. A left click anywhere else
            closes an open dropdown but is not consumed.
        """
        import pygame

        if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
            if self.dd_btn.collidepoint(e.pos):
                self.dd_open = not self.dd_open
                return True
            if self.dd_open:
                for i, name in enumerate(self.group_names):
                    item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H)
                    if item_rect.collidepoint(e.pos):
                        if name != self.current_group:
                            self.current_group = name
                            self._build_sliders()
                            self._adjust_height()
                        self.dd_open = False
                        return True
                self.dd_open = False
        return False

    def _save_current(self):
        """Copy the bounds of the visible sliders into the calibration and write it to the bus."""
        for s in self.sliders:
            self.calibration[s.motor].range_min = s.min_v
            self.calibration[s.motor].range_max = s.max_v

        with self.bus.torque_disabled():
            self.bus.write_calibration(self.calibration)

    def _load_current(self):
        """Re-read the calibration from the bus and reset the bounds of the visible sliders."""
        self.calibration = self.bus.read_calibration()
        for s in self.sliders:
            s.min_v = self.calibration[s.motor].range_min
            s.max_v = self.calibration[s.motor].range_max
            s.min_x = s._pos_from_val(s.min_v)
            s.max_x = s._pos_from_val(s.max_v)

    def run(self) -> dict[str, MotorCalibration]:
        """Run the event/draw loop until the window is closed.

        Returns:
            The calibration held by the GUI at the time the window is closed.
        """
        import pygame

        while True:
            for e in pygame.event.get():
                if e.type == pygame.QUIT:
                    pygame.quit()
                    return self.calibration

                # Events consumed by the dropdown must not reach the widgets below it.
                if self._handle_dropdown_event(e):
                    continue

                if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
                    if self.save_btn.collidepoint(e.pos):
                        self._save_current()
                    elif self.load_btn.collidepoint(e.pos):
                        self._load_current()

                for s in self.sliders:
                    s.handle_event(e)

            # live goal write while dragging
            for s in self.sliders:
                if s.drag_pos:
                    self.bus.write("Goal_Position", s.motor, s.pos_v, normalize=False)

            # tick update
            for s in self.sliders:
                pos = self.bus.read("Present_Position", s.motor, normalize=False)
                s.set_tick(pos)
                self.present_cache[s.motor] = pos

            # ─ drawing
            self.screen.fill(BG_COLOR)
            for s in self.sliders:
                s.draw(self.screen)

            # drawn after the sliders so that the open list appears on top of them
            self._draw_dropdown()

            # load / save buttons
            for rect, text in ((self.load_btn, "LOAD"), (self.save_btn, "SAVE")):
                clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR
                pygame.draw.rect(self.screen, clr, rect, border_radius=6)
                t = self.font.render(text, True, TEXT_COLOR)
                self.screen.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2))

            pygame.display.flip()
            self.clock.tick(FPS)
|
||||
@@ -162,11 +162,11 @@ class DynamixelMotorsBus(MotorsBus):
|
||||
|
||||
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
def configure_motors(self, return_delay_time=0) -> None:
|
||||
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
|
||||
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
||||
for motor in self.motors:
|
||||
self.write("Return_Delay_Time", motor, 0)
|
||||
self.write("Return_Delay_Time", motor, return_delay_time)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
@@ -190,13 +190,14 @@ class DynamixelMotorsBus(MotorsBus):
|
||||
|
||||
return calibration
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
for motor, calibration in calibration_dict.items():
|
||||
self.write("Homing_Offset", motor, calibration.homing_offset)
|
||||
self.write("Min_Position_Limit", motor, calibration.range_min)
|
||||
self.write("Max_Position_Limit", motor, calibration.range_max)
|
||||
|
||||
self.calibration = calibration_dict
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -164,8 +164,9 @@ class FeetechMotorsBus(MotorsBus):
|
||||
)
|
||||
|
||||
def _handshake(self) -> None:
|
||||
self._assert_motors_exist()
|
||||
self._assert_same_firmware()
|
||||
# self._assert_motors_exist()
|
||||
# self._assert_same_firmware()
|
||||
return
|
||||
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
if self.protocol_version == 0:
|
||||
@@ -219,94 +220,70 @@ class FeetechMotorsBus(MotorsBus):
|
||||
|
||||
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
def configure_motors(self, return_delay_time=0, maximum_acceleration=254, acceleration=254) -> None:
|
||||
for motor in self.motors:
|
||||
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
|
||||
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
||||
self.write("Return_Delay_Time", motor, 0)
|
||||
# self.write("Return_Delay_Time", motor, 0) # THIS DOES NOT WORK FOR HLS3625
|
||||
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
|
||||
# Note: this address is not in the official STS3215 Memory Table
|
||||
self.write("Maximum_Acceleration", motor, 254)
|
||||
self.write("Acceleration", motor, 254)
|
||||
if self.protocol_version == 0:
|
||||
self.write("Maximum_Acceleration", motor, maximum_acceleration)
|
||||
self.write("Acceleration", motor, acceleration)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
motors_calibration = self.read_calibration()
|
||||
if set(motors_calibration) != set(self.calibration):
|
||||
return False
|
||||
|
||||
same_ranges = all(
|
||||
self.calibration[motor].range_min == cal.range_min
|
||||
and self.calibration[motor].range_max == cal.range_max
|
||||
for motor, cal in motors_calibration.items()
|
||||
)
|
||||
if self.protocol_version == 1:
|
||||
return same_ranges
|
||||
|
||||
same_offsets = all(
|
||||
self.calibration[motor].homing_offset == cal.homing_offset
|
||||
for motor, cal in motors_calibration.items()
|
||||
)
|
||||
return same_ranges and same_offsets
|
||||
# Check if calibration data has been loaded from file
|
||||
return bool(self.calibration)
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
offsets, mins, maxes = {}, {}, {}
|
||||
for motor in self.motors:
|
||||
mins[motor] = self.read("Min_Position_Limit", motor, normalize=False)
|
||||
maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False)
|
||||
offsets[motor] = (
|
||||
self.read("Homing_Offset", motor, normalize=False) if self.protocol_version == 0 else 0
|
||||
)
|
||||
|
||||
# Return empty calibration - we don't read from motors anymore
|
||||
calibration = {}
|
||||
for motor, m in self.motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
homing_offset=0,
|
||||
range_min=0,
|
||||
range_max=4095, # Default max resolution
|
||||
)
|
||||
|
||||
return calibration
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||
for motor, calibration in calibration_dict.items():
|
||||
if self.protocol_version == 0:
|
||||
self.write("Homing_Offset", motor, calibration.homing_offset)
|
||||
self.write("Min_Position_Limit", motor, calibration.range_min)
|
||||
self.write("Max_Position_Limit", motor, calibration.range_max)
|
||||
|
||||
# Only update the in-memory calibration, don't write to motors
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
|
||||
"""
|
||||
On Feetech Motors:
|
||||
Present_Position = Actual_Position - Homing_Offset
|
||||
Calculate homing offsets such that the current position becomes 0 degrees.
|
||||
|
||||
For Feetech motors:
|
||||
- The homing offset is subtracted from the raw position during normalization
|
||||
- So to make current position = 0 degrees, homing_offset = current_raw_position
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
half_turn_homings[motor] = pos - int(max_res / 2)
|
||||
# The homing offset should be the current position
|
||||
# This way, when we normalize: (pos - homing_offset) = 0
|
||||
half_turn_homings[motor] = pos
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 5) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
# self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 5) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor_id, 0, num_retry=num_retry)
|
||||
# addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
# self._write(addr, length, motor_id, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 5) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 1, num_retry=num_retry)
|
||||
# self.write("Lock", motor, 1, num_retry=num_retry)
|
||||
|
||||
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
for id_ in ids_values:
|
||||
|
||||
@@ -151,6 +151,95 @@ SCS_SERIES_CONTROL_TABLE = {
|
||||
"Acceleration_2": (83, 1), # don't know what that is
|
||||
}
|
||||
|
||||
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
|
||||
HLS_SERIES_CONTROL_TABLE = {
|
||||
# Version Information (0-4) - read-only
|
||||
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # (0, 1) read-only
|
||||
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # (1, 1) read-only
|
||||
"End_Type": (2, 1), # read-only - 0 represents little-endian storage
|
||||
"Model_Number": MODEL_NUMBER, # (3, 2) read-only
|
||||
# EPROM configuration (5-39)
|
||||
"ID": (5, 1), # Main ID - unique identifier on bus
|
||||
"Baud_Rate": (6, 1), # 0-7 for different baud rates
|
||||
"Secondary_ID": (7, 1), # Secondary ID for write instructions
|
||||
"Response_Status_Level": (8, 1), # 0: limited response, 1: full response
|
||||
"Min_Position_Limit": (9, 2), # 0-4094 (0.087 degrees per unit)
|
||||
"Max_Position_Limit": (11, 2), # 1-4095 (0.087 degrees per unit)
|
||||
"Max_Temperature_Limit": (13, 1), # 0-100 (°C)
|
||||
"Max_Voltage_Limit": (14, 1), # 0-254 (0.1V per unit)
|
||||
"Min_Voltage_Limit": (15, 1), # 0-254 (0.1V per unit)
|
||||
"Max_Torque_Limit": (16, 2), # 0-1000 (0.1% per unit)
|
||||
"Phase": (18, 1), # Special function byte for motor phase configuration
|
||||
"Unloading_Condition": (19, 1), # Bit flags for protection conditions
|
||||
"LED_Alarm_Condition": (20, 1), # Bit flags for LED alarm conditions
|
||||
"P_Coefficient": (21, 1), # Position ring P proportional coefficient
|
||||
"D_Coefficient": (22, 1), # Position ring D differential coefficient
|
||||
"I_Coefficient": (23, 1), # Position ring I integral coefficient
|
||||
"Minimum_Startup_Force": (24, 1), # 0-254 (0.1% per unit)
|
||||
"Point_Limit_Value": (25, 1), # 0-254 - maximum point value = point_limit * 4
|
||||
"CW_Dead_Zone": (26, 1), # 0-16 (0.087 degrees per unit)
|
||||
"CCW_Dead_Zone": (27, 1), # 0-16 (0.087 degrees per unit)
|
||||
"Protection_Current": (28, 2), # 0-2047 (6.5 mA per unit)
|
||||
"Angle_Resolution": (30, 1), # 1-128 - amplification coefficient
|
||||
"Homing_Offset": (31, 2), # -4095 to 4095 (0.087 degrees per unit)
|
||||
"Operating_Mode": (33, 1), # 0: position, 1: speed, 2: current, 3: PWM
|
||||
"P_Coefficient_Curr": (34, 1), # Current ring P proportional coefficient
|
||||
"I_Coefficient_Curr": (35, 1), # Current ring I integral coefficient
|
||||
# Address 36 undefined
|
||||
"Speed_P_Coefficient": (37, 1), # Speed closed-loop P proportional coefficient
|
||||
"Overcurrent_Protection_Time": (38, 1), # 0-254 (10ms per unit)
|
||||
"Speed_I_Coefficient": (39, 1), # Speed closed-loop I integral coefficient
|
||||
# SRAM control (40-55)
|
||||
"Torque_Enable": (40, 1), # 0: off, 1: on, 2: damping
|
||||
"Acceleration": (41, 1), # 0-254 (8.7 degrees/second² per unit)
|
||||
"Goal_Position": (42, 2), # -32767 to 32767 (0.087 degrees per unit)
|
||||
"Target_Torque": (44, 2), # -2047 to 2047 (6.5 mA per unit)
|
||||
"Goal_Velocity": (46, 2), # -32767 to 32767 (0.732 RPM per unit)
|
||||
"Torque_Limit": (48, 2), # 0-1000 (0.1% per unit)
|
||||
"P_Coefficient_Ring": (50, 1), # Motor position ring proportional coefficient
|
||||
"D_Coefficient_Ring": (51, 1), # Motor position ring differential coefficient
|
||||
"I_Coefficient_Ring": (52, 1), # Motor position ring integral coefficient
|
||||
"km": (53, 1), # 0: position+current dual loop, 1: position single loop
|
||||
# Address 54 undefined
|
||||
"Lock": (55, 1), # 0: close write lock, 1: open write lock
|
||||
# SRAM feedback (56-73) - read-only
|
||||
"Present_Position": (56, 2), # read-only - current absolute position
|
||||
"Present_Velocity": (58, 2), # read-only - current motor rotation speed
|
||||
"Present_Load": (60, 2), # read-only - current load (0.1% per unit)
|
||||
"Present_Voltage": (62, 1), # read-only - current voltage (0.1V per unit)
|
||||
"Present_Temperature": (63, 1), # read-only - current temperature (°C)
|
||||
"Async_Write_Flag": (64, 1), # read-only - async write instruction flag
|
||||
"Status": (65, 1), # read-only - servo status bit flags
|
||||
"Moving": (66, 1), # read-only - movement status flags
|
||||
"Target_Position": (67, 2), # read-only - current target position
|
||||
"Present_Current": (69, 2), # read-only - current motor phase current (6.5 mA per unit)
|
||||
# Address 71 undefined
|
||||
"Present_Bias": (73, 2), # read-only - current 0-point offset value
|
||||
# Factory parameters (77-86) - read-only
|
||||
"VFk_x10": (77, 1), # read-only - factory parameter
|
||||
"vKgI": (78, 1), # read-only - factory parameter
|
||||
"PFk_x10": (79, 1), # read-only - factory parameter
|
||||
"Moving_Velocity_Threshold": (80, 1), # read-only - factory parameter
|
||||
"DTs_ms": (81, 1), # read-only - factory parameter
|
||||
"eFk_x10": (82, 1), # read-only - factory parameter
|
||||
"Vk_ms": (83, 1), # read-only - factory parameter
|
||||
"Maximum_Velocity_Limit": (84, 1), # read-only - factory parameter
|
||||
"Maximum_Acceleration": (85, 1), # read-only - factory parameter
|
||||
"Acceleration_Multiplier": (86, 1), # read-only - factory parameter
|
||||
}
|
||||
|
||||
# HLS series baud rate table (same as STS/SMS series)
|
||||
HLS_SERIES_BAUDRATE_TABLE = {
|
||||
1_000_000: 0,
|
||||
500_000: 1,
|
||||
250_000: 2,
|
||||
128_000: 3,
|
||||
115_200: 4,
|
||||
76_800: 5, # Note: HLS documentation mentions 76800 instead of 57600
|
||||
57_600: 6,
|
||||
38_400: 7,
|
||||
}
|
||||
|
||||
STS_SMS_SERIES_BAUDRATE_TABLE = {
|
||||
1_000_000: 0,
|
||||
500_000: 1,
|
||||
@@ -181,6 +270,7 @@ MODEL_CONTROL_TABLE = {
|
||||
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
"scs0009": SCS_SERIES_CONTROL_TABLE,
|
||||
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
"hls3625": HLS_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
@@ -189,8 +279,9 @@ MODEL_RESOLUTION = {
|
||||
"scs_series": 1024,
|
||||
"sts3215": 4096,
|
||||
"sts3250": 4096,
|
||||
"sm8512bl": 65536,
|
||||
"sm8512bl": 4096,
|
||||
"scs0009": 1024,
|
||||
"hls3625": 4096,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
@@ -201,6 +292,7 @@ MODEL_BAUDRATE_TABLE = {
|
||||
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
|
||||
"hls3625": HLS_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
# Sign-Magnitude encoding bits
|
||||
@@ -210,6 +302,18 @@ STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Present_Velocity": 15,
|
||||
}
|
||||
|
||||
# HLS series sign-magnitude encoding bits
|
||||
HLS_SERIES_ENCODINGS_TABLE = {
|
||||
"Homing_Offset": 15, # BIT15 represents positive/negative direction
|
||||
"Goal_Position": 15, # BIT15 represents positive/negative direction
|
||||
"Target_Torque": 15, # BIT15 represents positive/negative direction in constant current mode
|
||||
"Goal_Velocity": 15, # BIT15 represents positive/negative direction in constant speed mode
|
||||
"Present_Position": 15, # BIT15 represents positive/negative direction
|
||||
"Present_Velocity": 15, # BIT15 represents positive/negative direction
|
||||
"Present_Current": 15, # BIT15 represents positive/negative direction
|
||||
"Present_Load": 10, # BIT10 represents positive/negative direction
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
@@ -218,6 +322,7 @@ MODEL_ENCODING_TABLE = {
|
||||
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"scs0009": {},
|
||||
"hls3625": HLS_SERIES_ENCODINGS_TABLE,
|
||||
}
|
||||
|
||||
SCAN_BAUDRATES = [
|
||||
@@ -239,6 +344,7 @@ MODEL_NUMBER_TABLE = {
|
||||
"sts3250": 2825,
|
||||
"sm8512bl": 11272,
|
||||
"scs0009": 1284,
|
||||
"hls3625": 3338,
|
||||
}
|
||||
|
||||
MODEL_PROTOCOL = {
|
||||
@@ -249,4 +355,5 @@ MODEL_PROTOCOL = {
|
||||
"sts3250": 0,
|
||||
"sm8512bl": 0,
|
||||
"scs0009": 1,
|
||||
"hls3625": 0, # Uses FT-SCS protocol
|
||||
}
|
||||
|
||||
@@ -83,6 +83,9 @@ class MotorNormMode(str, Enum):
|
||||
DEGREES = "degrees"
|
||||
|
||||
|
||||
COUNT_TO_DEG = 0.087 # 1 encoder count = 0.087 °
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorCalibration:
|
||||
id: int
|
||||
@@ -441,8 +444,8 @@ class MotorsBus(abc.ABC):
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
elif handshake:
|
||||
self._handshake()
|
||||
# elif handshake:
|
||||
# self._handshake()
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
@@ -586,7 +589,7 @@ class MotorsBus(abc.ABC):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self):
|
||||
def torque_disabled(self, motors: int | str | list[str] | None = None):
|
||||
"""Context-manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
@@ -596,11 +599,11 @@ class MotorsBus(abc.ABC):
|
||||
... # Safe operations here
|
||||
... pass
|
||||
"""
|
||||
self.disable_torque()
|
||||
self.disable_torque(motors)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque()
|
||||
self.enable_torque(motors)
|
||||
|
||||
def set_timeout(self, timeout_ms: int | None = None):
|
||||
"""Change the packet timeout used by the SDK.
|
||||
@@ -653,12 +656,13 @@ class MotorsBus(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||
"""Write calibration parameters to the motors and cache them.
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration parameters to the motors and optionally cache them.
|
||||
|
||||
Args:
|
||||
calibration_dict (dict[str, MotorCalibration]): Calibration obtained from
|
||||
:pymeth:`read_calibration` or crafted by the user.
|
||||
cache (bool, optional): Save the calibration to :pyattr:`calibration`. Defaults to True.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -710,9 +714,8 @@ class MotorsBus(abc.ABC):
|
||||
self.reset_calibration(motors)
|
||||
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
homing_offsets = self._get_half_turn_homings(actual_positions)
|
||||
for motor, offset in homing_offsets.items():
|
||||
self.write("Homing_Offset", motor, offset)
|
||||
|
||||
# Don't write to motors, just return the calculated offsets
|
||||
return homing_offsets
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -781,21 +784,32 @@ class MotorsBus(abc.ABC):
|
||||
motor = self._id_to_name(id_)
|
||||
min_ = self.calibration[motor].range_min
|
||||
max_ = self.calibration[motor].range_max
|
||||
homing_offset = self.calibration[motor].homing_offset
|
||||
drive_mode = self.apply_drive_mode and self.calibration[motor].drive_mode
|
||||
|
||||
if max_ == min_:
|
||||
raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.")
|
||||
|
||||
bounded_val = min(max_, max(min_, val))
|
||||
if self.motors[motor].norm_mode is MotorNormMode.RANGE_M100_100:
|
||||
bounded_val = min(max_, max(min_, val))
|
||||
norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
|
||||
normalized_values[id_] = -norm if drive_mode else norm
|
||||
elif self.motors[motor].norm_mode is MotorNormMode.RANGE_0_100:
|
||||
bounded_val = min(max_, max(min_, val))
|
||||
norm = ((bounded_val - min_) / (max_ - min_)) * 100
|
||||
normalized_values[id_] = 100 - norm if drive_mode else norm
|
||||
elif self.motors[motor].norm_mode is MotorNormMode.DEGREES:
|
||||
mid = (min_ + max_) / 2
|
||||
max_res = self.model_resolution_table[self._id_to_model(id_)] - 1
|
||||
normalized_values[id_] = (val - mid) * 360 / max_res
|
||||
# For motors without wrap-around handling
|
||||
# The homing offset becomes 0 degrees
|
||||
|
||||
# Calculate difference from homing position
|
||||
diff = val - homing_offset
|
||||
|
||||
# Convert to degrees
|
||||
deg = diff * COUNT_TO_DEG
|
||||
|
||||
# Apply drive mode if needed
|
||||
normalized_values[id_] = -deg if drive_mode else deg
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -810,7 +824,9 @@ class MotorsBus(abc.ABC):
|
||||
motor = self._id_to_name(id_)
|
||||
min_ = self.calibration[motor].range_min
|
||||
max_ = self.calibration[motor].range_max
|
||||
homing_offset = self.calibration[motor].homing_offset
|
||||
drive_mode = self.apply_drive_mode and self.calibration[motor].drive_mode
|
||||
|
||||
if max_ == min_:
|
||||
raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.")
|
||||
|
||||
@@ -823,9 +839,22 @@ class MotorsBus(abc.ABC):
|
||||
bounded_val = min(100.0, max(0.0, val))
|
||||
unnormalized_values[id_] = int((bounded_val / 100) * (max_ - min_) + min_)
|
||||
elif self.motors[motor].norm_mode is MotorNormMode.DEGREES:
|
||||
mid = (min_ + max_) / 2
|
||||
max_res = self.model_resolution_table[self._id_to_model(id_)] - 1
|
||||
unnormalized_values[id_] = int((val * max_res / 360) + mid)
|
||||
# For motors without wrap-around, simple conversion back
|
||||
# Apply drive mode if needed
|
||||
val = -val if drive_mode else val
|
||||
|
||||
# Convert degrees to raw counts
|
||||
raw_counts = int(round(val / COUNT_TO_DEG))
|
||||
|
||||
# Add back the homing offset
|
||||
raw_counts_with_offset = raw_counts + homing_offset
|
||||
|
||||
# Ensure value stays within calibrated motor range
|
||||
# Use the calibration min/max if available
|
||||
if min_ is not None and max_ is not None:
|
||||
raw_counts_with_offset = max(min_, min(max_, raw_counts_with_offset))
|
||||
|
||||
unnormalized_values[id_] = raw_counts_with_offset
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -16,6 +16,5 @@ from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
@@ -107,7 +107,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
@@ -132,7 +132,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
@@ -485,12 +485,10 @@ class ACT(nn.Module):
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
if self.config.image_features:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
# NOTE: If modifying this section, verify on MPS devices that
|
||||
# gradients remain stable (no explosions or NaNs).
|
||||
for img in batch["observation.images"]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
@@ -500,11 +498,10 @@ class ACT(nn.Module):
|
||||
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
|
||||
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
|
||||
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
|
||||
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
|
||||
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
|
||||
# Extend immediately instead of accumulating and concatenating
|
||||
# Convert to list to extend properly
|
||||
encoder_in_tokens.extend(list(cam_features))
|
||||
encoder_in_pos_embed.extend(list(cam_pos_embed))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
|
||||
@@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
@@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
||||
@@ -75,10 +74,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolvla2":
|
||||
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
|
||||
|
||||
return SmolVLA2Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -100,8 +95,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "smolvla2":
|
||||
return SmolVLA2Config(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
|
||||
@@ -149,7 +149,7 @@ class Normalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
@@ -224,7 +224,7 @@ class Unnormalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
|
||||
@@ -260,12 +260,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("Currently not implemented for PI0")
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
|
||||
@@ -192,12 +192,12 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("Currently not implemented for PI0FAST")
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class SACPolicy(
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
|
||||
@@ -413,6 +413,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
@@ -422,7 +423,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class PEFTConfig:
    """LoRA hyper-parameters used when parameter-efficient fine-tuning is enabled.

    The values are forwarded to ``peft.LoraConfig`` by
    ``SmolVLMWithExpertModel.configure_peft``.
    """

    # Rank of the low-rank adaptation matrices.
    r: int = 4
    # LoRA scaling factor.
    lora_alpha: int = 16
    # Dropout applied to the LoRA layers.
    lora_dropout: float = 0.1
    # Comma-separated names of the modules LoRA is applied to
    # (split on "," by the consumer when it is not already a list).
    target_modules: str = "q_proj,v_proj"
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(PreTrainedConfig):
    """Configuration of the ``smolvla2`` policy.

    Registered as a ``PreTrainedConfig`` subclass under the name ``"smolvla2"``.
    Besides the plain hyper-parameter fields, it exposes the optimizer and
    scheduler presets and the delta indices used by the training pipeline.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (512, 512)

    # Add empty images. Used by smolvla_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48
    proj_width: int = 480

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5  # 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10
    optimizer_grad_clip_norm: float = 10
    optimizer_lr_vlm: float = 0

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"  # Select the VLM backbone.
    load_vlm_weights: bool = False  # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
    # The default is None, so the annotation has to admit it.
    checkpoint_path: str | None = None
    peft_method: str = ""
    peft_config: PEFTConfig = field(default_factory=PEFTConfig)
    peft_target_model: str = ""
    add_image_special_tokens: bool = False  # Whether to use special image tokens around image features.

    attention_mode: str = "cross_attn"

    prefix_length: int = -1

    # NOTE: this field used to be declared twice with the same default; the
    # redundant second declaration was dropped (field order is unchanged).
    pad_language_to: str = "longest"  # "max_length"

    num_expert_layers: int = -1  # Less than or equal to 0 is the default, where the action expert has the same number of layers as the VLM. Otherwise the expert has fewer layers.
    num_vlm_layers: int = 16  # Number of layers used in the VLM (first num_vlm_layers layers)
    past_obs_keys: str = "image"
    add_local_special_image_tokens: bool = False

    reverse_images_order: bool = False

    state_to_prefix: bool = False

    causal_action_attention_mask: bool = False

    self_attn_every_n_layers: int = -1  # Interleave SA layers each self_attn_every_n_layers (e.g. 2); values <= 0 disable the interleaving
    expert_width_multiplier: float = 0.75  # The action expert hidden size (wrt to the VLM)

    min_period: float = 4e-3  # sensitivity range for the timestep used in sine-cosine positional encoding
    max_period: float = 4.0

    robot_type: str = ""

    self_attn_only_actions: bool = False

    causal_attention_on_history: bool = False

    predict_relative_actions: bool = False
    relative_actions_mode: str = "first"

    shuffle_camera_positions: bool = False
    vlm_img_size: int = -1

    regression_loss: bool = False

    def __post_init__(self):
        """Input validation (not exhaustive).

        Raises:
            ValueError: if ``n_action_steps`` exceeds ``chunk_size``.
            NotImplementedError: if ``use_delta_joint_actions_aloha`` is set.
        """
        super().__post_init__()

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
            )

    def validate_features(self) -> None:
        """Register one placeholder visual input feature per configured empty camera."""
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        """Build the AdamW preset from the ``optimizer_*`` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Build the cosine-decay-with-warmup preset from the ``scheduler_*`` fields."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list:
        """Only the current observation (index 0) is used."""
        return [0]

    @property
    def action_delta_indices(self) -> list:
        """One index per step of the action chunk: ``0 .. chunk_size - 1``."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """No reward deltas are used by this policy."""
        return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,600 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
    """Apply rotary position embeddings (RoPE) to ``x``.

    Args:
        x: Tensor of shape [B, L, H, D]; D must be even because the last
            dimension is split into two halves that are rotated against
            each other.
        positions: Tensor of shape [B, L] holding the position of each token.
        max_wavelength: Base of the geometric progression of timescales.

    Returns:
        A tensor with the same shape and dtype as ``x``. The rotation itself
        is computed in float32 and cast back at the end.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    head_dim = x.shape[-1]
    if head_dim % 2 != 0:
        # Without this guard the two-way unpacking of `split` below fails with
        # an opaque "too many values to unpack" ValueError.
        raise ValueError(f"apply_rope expects an even last dimension, got {head_dim}.")

    d_half = head_dim // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / head_dim) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :]

    # Insert a singleton head axis so the angles broadcast over H.
    radians = radians[..., None, :]

    sin = torch.sin(radians)
    cos = torch.cos(radians)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim: int, ffn_dim_multiplier: float = 4, multiple_of: int = 256) -> int:
    """Compute the feed-forward intermediate size for a given hidden size.

    The hidden size is scaled by 2/3 (truncated), then by
    ``ffn_dim_multiplier`` (truncated), and finally rounded up to the next
    multiple of ``multiple_of``.

    Args:
        hidden_dim: Hidden size of the model.
        ffn_dim_multiplier: Multiplier applied after the 2/3 scaling.
        multiple_of: Granularity the result is rounded up to.

    Returns:
        The intermediate size, a multiple of ``multiple_of``.
    """
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Ceiling division, then scale back up: rounds up to a multiple of `multiple_of`.
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
config = self.vlm.config
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
self.vlm = SmolVLMForConditionalGeneration(config=config)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
if num_vlm_layers > 0:
|
||||
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
|
||||
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
|
||||
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
|
||||
self.config = config
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
if "cross" in attention_mode:
|
||||
# Reshape qkv projections to have the same input dimension as the vlm
|
||||
for layer_idx in range(len(self.lm_expert.layers)):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
|
||||
self.num_attention_heads = self.config.text_config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.text_config.num_key_value_heads
|
||||
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_mode = attention_mode
|
||||
self.expert_hidden_size = lm_expert_config.hidden_size
|
||||
self.set_requires_grad()
|
||||
|
||||
def configure_peft(self, config):
|
||||
# return model
|
||||
self.peft_method = config.peft_method
|
||||
self.peft_target_model = config.peft_target_model
|
||||
if "lora" in self.peft_method:
|
||||
peft_config = config.peft_config
|
||||
target_modules = peft_config.target_modules
|
||||
if not isinstance(target_modules, list):
|
||||
target_modules = target_modules.split(",")
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
|
||||
r=peft_config.r, # The rank of the low-rank adaptation
|
||||
lora_alpha=peft_config.lora_alpha, # Scaling factor
|
||||
lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
|
||||
target_modules=target_modules, # The components where LoRA is applied
|
||||
exclude_modules=[
|
||||
"lm_expert",
|
||||
"model.lm_expert.model.layers",
|
||||
], # FIXME(mshukor): this does not work for now
|
||||
)
|
||||
self.lora_config = lora_config
|
||||
# Apply LoRA and ensure only LoRA parameters are trainable
|
||||
if "text" in self.peft_target_model:
|
||||
self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
|
||||
else:
|
||||
self.vlm = get_peft_model(self.vlm, lora_config)
|
||||
# assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
|
||||
for name, param in self.vlm.named_parameters():
|
||||
if (
|
||||
"lora" in name and "text_model.model.layers.17" not in name
|
||||
): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
def merge_lora_weights(self):
|
||||
"""
|
||||
Merge LoRA weights into the base model.
|
||||
"""
|
||||
if "text" in self.peft_target_model:
|
||||
self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload()
|
||||
else:
|
||||
self.vlm = self.vlm.merge_and_unload()
|
||||
|
||||
def get_vlm_model(
|
||||
self,
|
||||
):
|
||||
if hasattr(self.vlm.model, "model"): # When using peft
|
||||
return self.vlm.model.model
|
||||
else:
|
||||
return self.vlm.model
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
for params in self.get_vlm_model().vision_model.parameters():
|
||||
params.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
for params in self.vlm.parameters():
|
||||
params.requires_grad = False
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any(k in name for k in frozen_layers):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
if "lm_head" in name:
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
patch_attention_mask = None
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(
    self,
    model_layers,
    inputs_embeds,
    layer_idx,
    position_ids,
    attention_mask,
    batch_size,
    head_dim,
    use_cache: bool = True,
    fill_kv_cache: bool = True,
    past_key_values=None,
) -> tuple[list[torch.Tensor], Optional[dict]]:
    """Joint self-attention over every active stream for one layer.

    Each non-``None`` entry of ``inputs_embeds`` whose matching layer exists is
    normalised and projected to Q/K/V by its own layer. The projections are
    concatenated along the sequence axis and attended to jointly.

    Args:
        model_layers: ``model_layers[i][layer_idx]`` is the decoder layer for stream
            ``i`` (or ``None`` when that stream has no layer at this depth).
        inputs_embeds: Per-stream hidden states; ``None`` entries are skipped.
        layer_idx: Index of the layer to run; also the key used in the KV cache.
        position_ids: Position ids; truncated to the concatenated sequence length
            when they are longer than it.
        attention_mask: Attention mask; truncated the same way as ``position_ids``.
        batch_size: Batch size forwarded to the attention implementation.
        head_dim: Head dimension forwarded to the attention implementation.
        use_cache: Whether the KV cache is read or written.
        fill_kv_cache: With ``use_cache``, ``True`` stores this call's K/V under
            ``layer_idx``; ``False`` prepends the cached K/V to this call's K/V.
        past_key_values: Cache dict keyed by layer index; created when ``use_cache``
            is set and ``None`` is passed.

    Returns:
        ``([att_output], past_key_values)``: a one-element list holding the joint
        attention output, and the (possibly updated) cache.
    """
    query_states = []
    key_states = []
    value_states = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = model_layers[i][layer_idx]
        if hidden_states is None or layer is None:
            continue
        hidden_states = layer.input_layernorm(hidden_states)

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

        hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
        query_states.append(layer.self_attn.q_proj(hidden_states).view(hidden_shape))
        key_states.append(layer.self_attn.k_proj(hidden_states).view(hidden_shape))
        value_states.append(layer.self_attn.v_proj(hidden_states).view(hidden_shape))

    # B,L,H,D with L sequence length, H number of heads, D head dim
    # concatenate on the number of embeddings/tokens
    query_states = torch.cat(query_states, dim=1)
    key_states = torch.cat(key_states, dim=1)
    value_states = torch.cat(value_states, dim=1)

    # Some streams may have been skipped above, so trim ids/mask to what is left.
    seq_len = query_states.shape[1]
    if seq_len < position_ids.shape[1]:
        position_ids = position_ids[:, :seq_len]
        attention_mask = attention_mask[:, :seq_len, :seq_len]

    query_states = apply_rope(query_states, position_ids)
    key_states = apply_rope(key_states, position_ids)

    if use_cache and past_key_values is None:
        past_key_values = {}

    if use_cache:
        if fill_kv_cache:
            past_key_values[layer_idx] = {
                "key_states": key_states,
                "value_states": value_states,
            }
        else:
            # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
            # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
            # the max len, then we (for instance) double the cache size. This implementation already exists
            # in `transformers`. (molbap)
            key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
            value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)

    attention_interface = self.get_attention_interface()

    att_output = attention_interface(
        attention_mask, batch_size, head_dim, query_states, key_states, value_states
    )
    return [att_output], past_key_values
|
||||
|
||||
def forward_cross_attn_layer(
    self,
    model_layers,
    inputs_embeds,
    layer_idx,
    position_ids,
    attention_mask,
    batch_size,
    head_dim,
    use_cache: bool = True,
    fill_kv_cache: bool = True,
    past_key_values=None,
) -> tuple[list[Optional[torch.Tensor]], Optional[dict]]:
    """Prefix self-attention plus expert cross-attention for one layer.

    Stream 0 (the prefix) self-attends unless a non-empty cache is supplied. The
    expert stream (stream 1) then builds its queries from its own hidden states,
    while its keys/values are the prefix K/V (fresh or cached) re-projected
    through the expert layer's ``k_proj`` / ``v_proj``.

    Args:
        model_layers: ``model_layers[0]`` holds the prefix layers and
            ``model_layers[1]`` the expert layers (``None`` where the expert has
            no layer at this depth).
        inputs_embeds: ``[prefix_hidden, expert_hidden]``.
        layer_idx: Index of the layer to run; also the key used in the KV cache.
        position_ids: Position ids. When the prefix branch runs they are split at
            the prefix length; otherwise they are used as-is for the expert.
        attention_mask: Mask sliced for the prefix block and for the expert rows.
        batch_size: Batch size forwarded to the attention implementation.
        head_dim: Head dimension forwarded to the attention implementation.
        use_cache: Whether the KV cache is read or written.
        fill_kv_cache: With ``use_cache``, ``True`` stores the prefix K/V under
            ``layer_idx`` (this requires the prefix branch to have run in this
            call); ``False`` reads the prefix K/V from the cache instead.
        past_key_values: Cache dict keyed by layer index.

    Returns:
        ``(att_outputs, past_key_values)``. ``att_outputs`` holds the prefix
        attention output (only when the prefix branch ran) followed by the expert
        attention output, or ``None`` when there is no expert layer at this depth.
    """
    attention_interface = self.get_attention_interface()

    att_outputs = []
    assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
        f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
    )

    if len(inputs_embeds) == 2 and not past_key_values:
        # Prefix attention
        seq_len = inputs_embeds[0].shape[1]
        position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
        prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]

        layer = model_layers[0][layer_idx]

        hidden_states = layer.input_layernorm(inputs_embeds[0])

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

        hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
        value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)

        # B,L,H,D with L sequence length, H number of heads, D head dim
        query_states = apply_rope(query_state, position_id)
        key_states = apply_rope(key_state, position_id)

        att_output = attention_interface(
            prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
        )
        att_outputs.append(att_output)
    else:
        # Prefix K/V come from the cache, so every position id belongs to the expert.
        expert_position_id = position_ids

    if use_cache and past_key_values is None:
        past_key_values = {}

    if use_cache:
        if fill_kv_cache:
            past_key_values[layer_idx] = {
                "key_states": key_states,
                "value_states": value_states,
            }
        else:
            # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
            # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
            # the max len, then we (for instance) double the cache size. This implementation already exists
            # in `transformers`. (molbap)
            key_states = past_key_values[layer_idx]["key_states"]
            value_states = past_key_values[layer_idx]["value_states"]

    # Expert
    expert_layer = model_layers[1][layer_idx]
    if expert_layer is not None:
        expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])

        expert_input_shape = expert_hidden_states.shape[:-1]
        expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)

        expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
        expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)

        # Flatten the prefix heads so the expert projections see (B, L, H*D).
        _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
            *key_states.shape[:2], -1
        )
        expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
            *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
        )  # k_proj should have same dim as kv

        _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
            *value_states.shape[:2], -1
        )
        expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
            *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
        )

        expert_position_id = (
            expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
        )  # start from 0
        expert_attention_mask = attention_mask[
            :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1]
        ]  # take into account kv

        expert_query_states = apply_rope(expert_query_state, expert_position_id)

        att_output = attention_interface(
            expert_attention_mask,
            batch_size,
            head_dim,
            expert_query_states,
            expert_key_states,
            expert_value_states,
        )
        att_outputs.append(att_output)
    else:
        att_outputs.append(None)

    return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list:
    """Pair every VLM layer with the expert layer that runs at the same depth.

    Args:
        models: ``[vlm_text_model, expert_model]``; both expose ``.layers``.

    Returns:
        ``[vlm_layers, expert_layers]``, two lists of length ``self.num_vlm_layers``.
        With ``stride = num_vlm_layers // num_expert_layers``, the expert entry is
        ``None`` at depths that are not a multiple of ``stride``; when ``stride``
        is not positive the expert layer with the same index is used.
    """
    stride = self.num_vlm_layers // self.num_expert_layers
    paired_vlm = []
    paired_expert = []
    for depth in range(self.num_vlm_layers):
        if stride <= 0 or depth % stride == 0:
            expert_idx = depth // stride if stride > 0 else depth
            expert_at_depth = models[1].layers[expert_idx]
        else:
            expert_at_depth = None
        paired_vlm.append(models[0].layers[depth])
        paired_expert.append(expert_at_depth)
    return [paired_vlm, paired_expert]
|
||||
|
||||
def forward(
    self,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: List[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    fill_kv_cache: Optional[bool] = None,
):
    """Run every layer of the VLM text model and the expert over ``inputs_embeds``.

    For each layer index the streams are attended either jointly
    (``forward_attn_layer``) or with prefix/expert cross-attention
    (``forward_cross_attn_layer``). Each stream then goes through its own output
    projection, residual connections, post-attention norm and MLP. A final norm
    is applied per stream after the last layer.

    Args:
        attention_mask: Mask forwarded to the attention layers.
        position_ids: Position ids forwarded to the attention layers.
        past_key_values: KV cache forwarded to and returned by the attention
            layers, which index it by layer number.
        inputs_embeds: ``[vlm_hidden, expert_hidden]``; an entry may be ``None``
            to skip that stream, but at least one must be a tensor.
        use_cache: Whether the attention layers read/write the KV cache.
        fill_kv_cache: When truthy, joint self-attention is used on every layer.

    Returns:
        ``(outputs_embeds, past_key_values)`` where ``outputs_embeds`` has one
        entry per input stream (``None`` for skipped streams).

    Raises:
        ValueError: If ``inputs_embeds`` is ``None`` or contains no tensor.
    """
    # Without at least one tensor there is no batch size to derive; fail with an
    # explicit message instead of a TypeError / UnboundLocalError further down.
    if inputs_embeds is None or all(hidden_states is None for hidden_states in inputs_embeds):
        raise ValueError("`inputs_embeds` must contain at least one non-None tensor.")

    models = [self.get_vlm_model().text_model, self.lm_expert]
    model_layers = self.get_model_layers(models)
    for hidden_states in inputs_embeds:
        # TODO this is very inefficient
        # dtype is always the same, batch size too (if > 1 len)
        # device could be trickier in multi gpu edge cases but that's it
        if hidden_states is None:
            continue
        batch_size = hidden_states.shape[0]

    # Layer loop: attention, then per-stream output projection + MLP.
    num_layers = self.num_vlm_layers
    head_dim = self.vlm.config.text_config.head_dim
    for layer_idx in range(num_layers):
        if (
            fill_kv_cache
            or "cross" not in self.attention_mode
            or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
        ):
            att_outputs, past_key_values = self.forward_attn_layer(
                model_layers,
                inputs_embeds,
                layer_idx,
                position_ids,
                attention_mask,
                batch_size,
                head_dim,
                use_cache=use_cache,
                fill_kv_cache=fill_kv_cache,
                past_key_values=past_key_values,
            )
        else:
            att_outputs, past_key_values = self.forward_cross_attn_layer(
                model_layers,
                inputs_embeds,
                layer_idx,
                position_ids,
                attention_mask,
                batch_size,
                head_dim,
                use_cache=use_cache,
                fill_kv_cache=fill_kv_cache,
                past_key_values=past_key_values,
            )
        outputs_embeds = []
        start = 0
        for i, hidden_states in enumerate(inputs_embeds):
            layer = model_layers[i][layer_idx]
            att_output = (
                att_outputs[i] if i < len(att_outputs) else att_outputs[0]
            )  # in case of self_attn
            if hidden_states is not None:
                if layer is None:
                    # No layer for this stream at this depth: pass it through.
                    outputs_embeds.append(hidden_states)
                    continue
                end = start + hidden_states.shape[1]

                if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                    att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                # With a single joint output, each stream owns a slice of it.
                att_out = att_output[:, start:end]
                out_emb = layer.self_attn.o_proj(att_out)

                # First residual connection
                out_emb += hidden_states
                after_first_residual = out_emb.clone()

                out_emb = layer.post_attention_layernorm(out_emb)
                out_emb = layer.mlp(out_emb)

                # Second residual connection
                out_emb += after_first_residual

                outputs_embeds.append(out_emb)

                start = end if len(att_outputs) == 1 else 0
            else:
                outputs_embeds.append(None)

        inputs_embeds = outputs_embeds

    # final norm
    outputs_embeds = []
    for i, hidden_states in enumerate(inputs_embeds):
        if hidden_states is not None:
            out_emb = models[i].norm(hidden_states)
            outputs_embeds.append(out_emb)
        else:
            outputs_embeds.append(None)
    return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
    """Return the attention callable to use; currently always the eager implementation."""
    return self.eager_attention_forward
|
||||
|
||||
def eager_attention_forward(
    self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
    """Masked scaled dot-product attention with grouped key/value heads.

    Key and value heads are repeated ``num_attention_heads // num_key_value_heads``
    times so they line up with the query heads. Scores are computed in float32,
    masked with the most negative float32 value wherever ``attention_mask`` is
    ``False``, soft-maxed, and applied to the values.

    Args:
        attention_mask: Boolean mask indexed as ``[batch, query, key]``.
        batch_size: Batch size used for the expand/reshape targets.
        head_dim: Size of each attention head.
        query_states: Queries laid out as (B, L, H, D).
        key_states: Keys laid out as (B, L, H_kv, D).
        value_states: Values laid out as (B, L, H_kv, D).

    Returns:
        Attention output reshaped to ``(batch_size, -1, heads * head_dim)``.
    """
    n_kv_heads = self.num_key_value_heads
    n_groups = self.num_attention_heads // n_kv_heads
    kv_len = key_states.shape[1]

    def _repeat_kv_heads(states):
        # (B, L, H_kv, D) -> (B, L, H_kv * groups, D)
        widened = states[:, :, :, None, :].expand(batch_size, kv_len, n_kv_heads, n_groups, head_dim)
        return widened.reshape(batch_size, kv_len, n_kv_heads * n_groups, head_dim)

    key_states = _repeat_kv_heads(key_states)
    value_states = _repeat_kv_heads(value_states)

    # Attention here is upcasted to float32 to match the original eager implementation.
    queries = query_states.to(dtype=torch.float32).transpose(1, 2)
    keys = key_states.to(dtype=torch.float32).transpose(1, 2)

    scores = torch.matmul(queries, keys.transpose(2, 3))
    scores *= head_dim**-0.5
    scores = scores.to(dtype=torch.float32)

    # Masked-out positions get the lowest representable score. See gemma/modules.py
    mask_fill = torch.finfo(scores.dtype).min
    scores = torch.where(attention_mask[:, None, :, :], scores, mask_fill)
    probs = nn.functional.softmax(scores, dim=-1).to(dtype=value_states.dtype)

    context = torch.matmul(probs, value_states.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
    # we use -1 because sequence length can change
    return context.reshape(batch_size, -1, n_kv_heads * n_groups * head_dim)
|
||||
@@ -110,7 +110,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||
|
||||
@@ -124,14 +124,14 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.action_chunk_size),
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ Example:
|
||||
python -m lerobot.record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480}}" \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.num_episodes=2 \
|
||||
@@ -33,6 +33,41 @@ python -m lerobot.record \
|
||||
# <- Policy optional if you want to record with a policy \
|
||||
# --policy.path=${HF_USER}/my_policy \
|
||||
```
|
||||
|
||||
Example with bilateral teleoperation:
|
||||
|
||||
```shell
|
||||
python -m lerobot.record \
|
||||
--robot.type=so101_follower_t \
|
||||
--robot.port=/dev/tty.usbmodem58760432961 \
|
||||
--robot.id=follower_arm_torque \
|
||||
--dataset.repo_id=pepijn/bilateral-teleop-test \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Wipe the table" \
|
||||
--biteleop=true \
|
||||
--teleop.type=so101_follower_t \
|
||||
--teleop.port=/dev/tty.usbmodem58760432571 \
|
||||
--teleop.id=leader_arm_torque \
|
||||
--dataset.fps=100 \
|
||||
--robot.cameras="{side: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 100}}" \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
Example Eval with bilateral teleoperation:
|
||||
```
|
||||
python -m lerobot.record \
|
||||
--robot.type=so101_follower_t \
|
||||
--robot.port=/dev/tty.usbmodem58760432961 \
|
||||
--robot.id=follower_arm_torque \
|
||||
--robot.cameras="{side: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 100}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=pepijn223/eval_bilateral-wipe-large \
|
||||
--dataset.single_task="Wipe the table" \
|
||||
--policy.path=pepijn223/bilateral-wipe-large-single \
|
||||
--dataset.fps=100 \
|
||||
--biteleop=true
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -57,14 +92,18 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
so101_follower_torque,
|
||||
)
|
||||
from lerobot.robots.so101_follower_torque import SO101FollowerT
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
@@ -87,6 +126,24 @@ from lerobot.utils.utils import (
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
|
||||
|
||||
def split_interleaved_action(vec, motors):
    """
    Split a flat, per-joint interleaved vector into position, velocity and torque.

    vec     : indexable 1-D tensor/array laid out as
              [pos_0, vel_0, tau_0, pos_1, vel_1, tau_1, …]; entries
              0 … 3*len(motors)-1 are read
    motors  : iterable of joint names, e.g. ['shoulder_pan', 'shoulder_lift', …]

    returns : pos, vel, tau  (three dicts keyed by joint name, values cast to float)
    """
    pos, vel, tau = {}, {}, {}
    for joint_idx, joint in enumerate(motors):
        offset = 3 * joint_idx
        pos[joint], vel[joint], tau[joint] = (float(vec[offset + k]) for k in range(3))
    return pos, vel, tau
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
@@ -131,7 +188,7 @@ class RecordConfig:
|
||||
robot: RobotConfig
|
||||
dataset: DatasetRecordConfig
|
||||
# Whether to control the robot with a teleoperator
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
teleop: TeleoperatorConfig | RobotConfig | None = None
|
||||
# Whether to control the robot with a policy
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Display all cameras on screen
|
||||
@@ -140,6 +197,8 @@ class RecordConfig:
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Enable bilateral teleoperation with force feedback
|
||||
biteleop: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -164,15 +223,23 @@ def record_loop(
|
||||
events: dict,
|
||||
fps: int,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | List[Teleoperator] | None = None,
|
||||
teleop: Teleoperator | List[Teleoperator] | Robot | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
biteleop: bool = False,
|
||||
):
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
if biteleop and policy is None:
|
||||
if not isinstance(robot, SO101FollowerT):
|
||||
raise ValueError(
|
||||
"Bilateral teleoperation requires both robot and teleop to be of type SO101FollowerT"
|
||||
)
|
||||
logging.info("Bilateral teleoperation mode enabled")
|
||||
|
||||
teleop_arm = teleop_keyboard = None
|
||||
if isinstance(teleop, list):
|
||||
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
|
||||
@@ -196,7 +263,11 @@ def record_loop(
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
|
||||
loop_count = 0
|
||||
rerun_log_freq = max(1, int(fps / 10))
|
||||
|
||||
while control_time_s is not None and timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
@@ -208,7 +279,67 @@ def record_loop(
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
|
||||
if policy is not None:
|
||||
if (
|
||||
biteleop
|
||||
and isinstance(robot, SO101FollowerT)
|
||||
and isinstance(teleop, SO101FollowerT)
|
||||
and policy is None
|
||||
):
|
||||
obs_f = observation # robot is the follower
|
||||
obs_l = teleop.get_observation()
|
||||
|
||||
pos_f = {j: obs_f[f"{j}.pos"] for j in robot.bus.motors}
|
||||
vel_f = {j: obs_f[f"{j}.vel"] for j in robot.bus.motors}
|
||||
tau_reaction_f = {j: obs_f[f"{j}.effort"] for j in robot.bus.motors}
|
||||
|
||||
pos_l = {j: obs_l[f"{j}.pos"] for j in teleop.bus.motors}
|
||||
vel_l = {j: obs_l[f"{j}.vel"] for j in teleop.bus.motors}
|
||||
acc_l = {j: obs_l[f"{j}.acc"] for j in teleop.bus.motors}
|
||||
tau_reaction_l = {j: obs_l[f"{j}.effort"] for j in teleop.bus.motors}
|
||||
|
||||
# Get control gains from robot
|
||||
kp_gains = robot.kp_gains
|
||||
kd_gains = robot.kd_gains
|
||||
kf_gains = robot.kf_gains
|
||||
|
||||
# Compute torque commands
|
||||
tau_cmd_f = [
|
||||
(
|
||||
kp_gains[j] * (pos_l[j] - pos_f[j]) # Position tracking
|
||||
+ kd_gains[j] * (vel_l[j] - vel_f[j]) # Velocity damping
|
||||
+ kf_gains[j] * (-tau_reaction_l[j] - tau_reaction_f[j])
|
||||
) # Force reflection
|
||||
for j in robot.bus.motors
|
||||
]
|
||||
|
||||
tau_cmd_l = [
|
||||
(
|
||||
kp_gains[j] * (pos_f[j] - pos_l[j]) # Position tracking
|
||||
+ kd_gains[j] * (vel_f[j] - vel_l[j]) # Velocity damping
|
||||
+ kf_gains[j] * (-tau_reaction_f[j] - tau_reaction_l[j])
|
||||
) # Force reflection
|
||||
for j in teleop.bus.motors
|
||||
]
|
||||
|
||||
action = {f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(robot.bus.motors)}
|
||||
teleop_action = {f"{m}.effort": tau_cmd_l[i] for i, m in enumerate(teleop.bus.motors)}
|
||||
teleop.send_action(teleop_action)
|
||||
robot.send_action(action)
|
||||
|
||||
# For bilateral teleoperation, create custom observation and action for dataset
|
||||
bilateral_action = {}
|
||||
for j in teleop.bus.motors:
|
||||
bilateral_action[f"{j}.pos"] = pos_l[j]
|
||||
bilateral_action[f"{j}.vel"] = vel_l[j]
|
||||
bilateral_action[f"{j}.acc"] = acc_l[j]
|
||||
bilateral_action[f"{j}.effort"] = -tau_reaction_l[j]
|
||||
|
||||
# Override the observation_frame and action for dataset recording
|
||||
if dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
action = bilateral_action
|
||||
|
||||
elif policy is not None and biteleop and isinstance(robot, SO101FollowerT):
|
||||
action_values = predict_action(
|
||||
observation_frame,
|
||||
policy,
|
||||
@@ -217,10 +348,57 @@ def record_loop(
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
pos_f = {j: observation[f"{j}.pos"] for j in robot.bus.motors}
|
||||
vel_f = {j: observation[f"{j}.vel"] for j in robot.bus.motors}
|
||||
tau_reaction_f = {j: observation[f"{j}.effort"] for j in robot.bus.motors}
|
||||
|
||||
# split_interleaved_action reads the model output as per-joint interleaved: [pos1, vel1, tau1, pos2, vel2, tau2, …]
|
||||
motors = robot.bus.motors # 6 joints
|
||||
pos_l, vel_l, neg_tau_reaction_l = split_interleaved_action(
|
||||
action_values, motors
|
||||
) # The model is trained and returns the effort already as negative: -tau_reaction_l
|
||||
|
||||
kp, kd, kf = robot.kp_gains, robot.kd_gains, robot.kf_gains
|
||||
|
||||
# Compute torque command for the follower robot
|
||||
tau_cmd_f = [
|
||||
(
|
||||
kp[j] * (pos_l[j] - pos_f[j]) # Position tracking
|
||||
+ kd[j] * (vel_l[j] - vel_f[j]) # Velocity damping
|
||||
+ kf[j] * (neg_tau_reaction_l[j] - tau_reaction_f[j]) # Force reflection
|
||||
)
|
||||
for j in robot.bus.motors
|
||||
]
|
||||
|
||||
# Format action with calculated torques and send to robot
|
||||
action = {f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(robot.bus.motors)}
|
||||
robot.send_action(action)
|
||||
|
||||
bilateral_action = {}
|
||||
for j in robot.bus.motors:
|
||||
bilateral_action[f"{j}.pos"] = pos_l[j]
|
||||
bilateral_action[f"{j}.vel"] = vel_l[j]
|
||||
bilateral_action[f"{j}.effort"] = neg_tau_reaction_l[j]
|
||||
|
||||
# Override the observation_frame and action for dataset recording
|
||||
if dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
action = bilateral_action
|
||||
|
||||
elif policy is not None and not biteleop:
|
||||
action_values = predict_action(
|
||||
observation_frame,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
elif policy is None and isinstance(teleop, Teleoperator) and not biteleop:
|
||||
action = teleop.get_action()
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
elif policy is None and isinstance(teleop, list) and not biteleop:
|
||||
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
@@ -239,31 +417,62 @@ def record_loop(
|
||||
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
sent_action = robot.send_action(action)
|
||||
if not biteleop:
|
||||
sent_action = robot.send_action(action)
|
||||
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
# For bilateral teleoperation, use the bilateral_action (leader pos & torque)
|
||||
# For other modes, use sent_action as usual
|
||||
if biteleop and isinstance(robot, SO101FollowerT):
|
||||
action_frame = build_dataset_frame(dataset.features, action, prefix="action")
|
||||
else:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
if display_data and loop_count % rerun_log_freq == 0:
|
||||
log_rerun_data(observation, action)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
loop_count += 1
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
if cfg.display_data:
|
||||
_init_rerun(session_name="recording")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
|
||||
|
||||
if cfg.biteleop and cfg.teleop is not None:
|
||||
print("Bilateral teleoperation enabled")
|
||||
# For bilateral teleoperation, both arms must be SO101FollowerT robots
|
||||
from lerobot.robots.so101_follower_torque.config_so101_follower_t import SO101FollowerTConfig
|
||||
|
||||
# Check if teleop config has the right type
|
||||
if cfg.teleop.type != "so101_follower_t":
|
||||
raise ValueError("Bilateral teleoperation requires teleop.type to be 'so101_follower_t'")
|
||||
|
||||
port = getattr(cfg.teleop, "port", None)
|
||||
if port is None:
|
||||
raise ValueError("Bilateral teleoperation requires teleop.port to be specified")
|
||||
|
||||
teleop_robot_config = SO101FollowerTConfig(
|
||||
port=port,
|
||||
id=getattr(cfg.teleop, "id", "leader_arm_torque"),
|
||||
cameras=getattr(cfg.teleop, "cameras", {}),
|
||||
disable_torque_on_disconnect=getattr(cfg.teleop, "disable_torque_on_disconnect", True),
|
||||
)
|
||||
|
||||
teleop = SO101FollowerT(teleop_robot_config)
|
||||
else:
|
||||
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
|
||||
@@ -317,6 +526,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
biteleop=cfg.biteleop,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
@@ -333,6 +543,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
biteleop=cfg.biteleop,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -25,6 +25,17 @@ python -m lerobot.replay \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
|
||||
Biteleop example:
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
--robot.type=so101_follower_t \
|
||||
--robot.port=/dev/tty.usbmodem58760432961 \
|
||||
--robot.id=follower_arm_torque \
|
||||
--dataset.repo_id=pepijn223/bilateral-wipe-large \
|
||||
--dataset.episode=10 \
|
||||
--biteleop=true
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -39,11 +50,14 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
so101_follower_torque,
|
||||
)
|
||||
from lerobot.robots.so101_follower_torque import SO101FollowerT
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -69,6 +83,8 @@ class ReplayConfig:
|
||||
dataset: DatasetReplayConfig
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Use biteleop to replay the dataset
|
||||
biteleop: bool = False
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@@ -79,22 +95,70 @@ def replay(cfg: ReplayConfig):
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
if cfg.biteleop:
|
||||
if not isinstance(robot, SO101FollowerT):
|
||||
raise ValueError(
|
||||
"Bilateral teleoperation replay requires the robot to be of type SO101FollowerT."
|
||||
)
|
||||
log_say("Bilateral teleoperation replay enabled.", cfg.play_sounds)
|
||||
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
|
||||
start_time_all = time.perf_counter()
|
||||
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
action = {}
|
||||
action_from_ds_array = actions[idx]["action"]
|
||||
action_from_ds = {}
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
action[name] = action_array[i]
|
||||
action_from_ds[name] = action_from_ds_array[i]
|
||||
|
||||
robot.send_action(action)
|
||||
# Bilateral teleoperation
|
||||
if cfg.biteleop:
|
||||
obs_f = robot.get_observation()
|
||||
pos_f = {j: obs_f[f"{j}.pos"] for j in robot.bus.motors}
|
||||
vel_f = {j: obs_f[f"{j}.vel"] for j in robot.bus.motors}
|
||||
tau_reaction_f = {j: obs_f[f"{j}.effort"] for j in robot.bus.motors}
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
pos_l = {j: action_from_ds[f"{j}.pos"] for j in robot.bus.motors}
|
||||
vel_l = {j: action_from_ds[f"{j}.vel"] for j in robot.bus.motors}
|
||||
# The saved effort in dataset is -tau_reaction_l
|
||||
neg_tau_reaction_l = {j: action_from_ds[f"{j}.effort"] for j in robot.bus.motors}
|
||||
|
||||
# Get control gains from the robot instance
|
||||
kp_gains = robot.kp_gains
|
||||
kd_gains = robot.kd_gains
|
||||
kf_gains = robot.kf_gains
|
||||
|
||||
# Compute torque command for the follower robot
|
||||
tau_cmd_f = [
|
||||
(
|
||||
kp_gains[j] * (pos_l[j] - pos_f[j]) # Position tracking
|
||||
+ kd_gains[j] * (vel_l[j] - vel_f[j]) # Velocity damping
|
||||
+ kf_gains[j] * (neg_tau_reaction_l[j] - tau_reaction_f[j]) # Force reflection
|
||||
)
|
||||
for j in robot.bus.motors
|
||||
]
|
||||
|
||||
# Format action with calculated torques and send to robot
|
||||
action_to_send = {f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(robot.bus.motors)}
|
||||
robot.send_action(action_to_send)
|
||||
else:
|
||||
# Original logic for standard position-based replay
|
||||
robot.send_action(action_from_ds)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
total_time = time.perf_counter() - start_time_all
|
||||
actual_fps = idx / total_time if total_time > 0 else float("inf")
|
||||
logging.info(f"Average FPS achieved over episode: {actual_fps:.2f}")
|
||||
log_say(f"Average FPS achieved: {actual_fps:.2f}", cfg.play_sounds)
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
|
||||
3
src/lerobot/robots/hope_jr/__init__.py
Normal file
3
src/lerobot/robots/hope_jr/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig
|
||||
from .hope_jr_arm import HopeJrArm
|
||||
from .hope_jr_hand import HopeJrHand
|
||||
51
src/lerobot/robots/hope_jr/config_hope_jr.py
Normal file
51
src/lerobot/robots/hope_jr/config_hope_jr.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("hope_jr_hand")
@dataclass
class HopeJrHandConfig(RobotConfig):
    """Configuration for a HopeJR robot hand.

    Raises:
        ValueError: from ``__post_init__`` when ``side`` is neither "left" nor "right".
    """

    port: str  # Port to connect to the hand
    side: str  # "left" / "right"

    disable_torque_on_disconnect: bool = True

    cameras: dict[str, CameraConfig] = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        # Fail early with a message that names the field and the accepted values, instead of
        # an exception whose only content is the offending value.
        if self.side not in ("left", "right"):
            raise ValueError(f"`side` must be either 'left' or 'right', got {self.side!r}.")
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("hope_jr_arm")
@dataclass
class HopeJrArmConfig(RobotConfig):
    # Port to connect to the arm
    port: str

    # Passed to the motor bus when the arm is disconnected.
    disable_torque_on_disconnect: bool = True

    # Safety cap on the magnitude of the relative positional target. When this is not None,
    # the arm compares each goal position against the present position before writing it.
    # NOTE(review): earlier documentation also mentioned a per-motor list, but the annotation
    # only admits a single int — confirm which is intended.
    max_relative_target: int | None = None

    # Camera configurations, keyed by camera name.
    cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
268
src/lerobot/robots/hope_jr/hope_jr.mdx
Normal file
268
src/lerobot/robots/hope_jr/hope_jr.mdx
Normal file
@@ -0,0 +1,268 @@
|
||||
# HopeJR
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- [Hardware Setup](https://github.com/TheRobotStudio/HOPEJr)
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
|
||||
|
||||
Install LeRobot with HopeJR dependencies:
|
||||
```bash
|
||||
pip install -e ".[hopejr]"
|
||||
```
|
||||
|
||||
## Device Configuration
|
||||
|
||||
Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
```
|
||||
|
||||
This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
|
||||
|
||||
## Step 1: Calibration
|
||||
|
||||
Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibration files will be saved in `~/.cache/huggingface/lerobot/calibration`.
|
||||
|
||||
### 1.1 Calibrate Robot Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
--robot.side=right
|
||||
```
|
||||
|
||||
When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows:
|
||||
|
||||
**Thumb**:
|
||||
- **CMC**: base joint connecting thumb to hand
|
||||
- **MCP**: knuckle joint
|
||||
- **PIP**: first finger joint
|
||||
- **DIP**: fingertip joint
|
||||
|
||||
**Index, Middle, Ring, and Pinky fingers**:
|
||||
- **Radial flexor**: Moves base of finger towards the thumb
|
||||
- **Ulnar flexor**: Moves base of finger towards the pinky
|
||||
- **PIP/DIP**: Flexes the distal and proximal phalanges of the finger
|
||||
|
||||
Each one of these will need to be calibrated individually via the GUI.
|
||||
Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/calibration_gui_1.png"
|
||||
alt="Setting boundaries in the hand calibration GUI"
|
||||
title="Setting boundaries in the hand calibration GUI"
|
||||
width="100%">
|
||||
</img>
|
||||
</p>
|
||||
|
||||
Use the calibration interface to set the range boundaries for each joint as shown above.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/calibration_gui_2.png"
|
||||
alt="Saving calibration values"
|
||||
title="Saving calibration values"
|
||||
width="100%">
|
||||
</img>
|
||||
</p>
|
||||
|
||||
Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors.
|
||||
|
||||
### 1.2 Calibrate Teleoperator Glove
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
--teleop.type=homunculus_glove \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=red \
|
||||
--teleop.side=right
|
||||
```
|
||||
|
||||
Move each finger through its full range of motion, starting from the thumb.
|
||||
|
||||
```
|
||||
Move thumb through its entire range of motion.
|
||||
Recording positions. Press ENTER to stop...
|
||||
|
||||
-------------------------------------------
|
||||
NAME | MIN | POS | MAX
|
||||
thumb_cmc | 1790 | 1831 | 1853
|
||||
thumb_mcp | 1497 | 1514 | 1528
|
||||
thumb_pip | 1466 | 1496 | 1515
|
||||
thumb_dip | 1463 | 1484 | 1514
|
||||
```
|
||||
|
||||
Continue with each finger:
|
||||
|
||||
```
|
||||
Move middle through its entire range of motion.
|
||||
Recording positions. Press ENTER to stop...
|
||||
|
||||
-------------------------------------------
|
||||
NAME | MIN | POS | MAX
|
||||
middle_mcp_abduction | 1598 | 1718 | 1820
|
||||
middle_mcp_flexion | 1512 | 1658 | 2136
|
||||
middle_dip | 1484 | 1500 | 1547
|
||||
```
|
||||
|
||||
Once calibration is complete, the system will save the calibration to `/Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_glove/red.json`
|
||||
|
||||
### 1.3 Calibrate Robot Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white
|
||||
```
|
||||
|
||||
This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows:
|
||||
- **Shoulder**: pitch, yaw, and roll
|
||||
- **Elbow**: flex
|
||||
- **Wrist**: pitch, yaw, and roll
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/calibration_gui_2.png"
|
||||
alt="Setting boundaries in the arm calibration GUI"
|
||||
title="Setting boundaries in the arm calibration GUI"
|
||||
width="100%">
|
||||
</img>
|
||||
</p>
|
||||
|
||||
Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration.
|
||||
|
||||
### 1.4 Calibrate Teleoperator Exoskeleton
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
--teleop.type=homunculus_arm \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=black
|
||||
```
|
||||
|
||||
The exoskeleton allows one to control the robot arm. During calibration, you'll be prompted to move all joints through their full range of motion:
|
||||
|
||||
```
|
||||
Move all joints through their entire range of motion.
|
||||
Recording positions. Press ENTER to stop...
|
||||
|
||||
-------------------------------------------
|
||||
-------------------------------------------
|
||||
NAME | MIN | POS | MAX
|
||||
shoulder_pitch | 586 | 736 | 895
|
||||
shoulder_yaw | 1257 | 1374 | 1390
|
||||
shoulder_roll | 449 | 1034 | 2564
|
||||
elbow_flex | 3023 | 3117 | 3134
|
||||
wrist_roll | 3073 | 3096 | 3147
|
||||
wrist_yaw | 2143 | 2171 | 2185
|
||||
wrist_pitch | 1975 | 1993 | 2074
|
||||
Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_arm/black.json
|
||||
```
|
||||
|
||||
## Step 2: Teleoperation
|
||||
|
||||
Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions:
|
||||
|
||||
### Hand
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
--robot.side=right \
|
||||
--teleop.type=homunculus_glove \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=red \
|
||||
--teleop.side=right \
|
||||
--display_data=true \
|
||||
--fps=30
|
||||
```
|
||||
|
||||
### Arm
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white \
|
||||
--teleop.type=homunculus_arm \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=black \
|
||||
--display_data=true \
|
||||
--fps=30
|
||||
```
|
||||
|
||||
## Step 3: Record, Replay, Train
|
||||
|
||||
Recording, replaying, and training with HopeJR are still experimental.
|
||||
|
||||
### Record
|
||||
|
||||
This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=homunculus_glove \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
### Replay
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
### Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
|
||||
This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
176
src/lerobot/robots/hope_jr/hope_jr_arm.py
Normal file
176
src/lerobot/robots/hope_jr/hope_jr_arm.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_hope_jr import HopeJrArmConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HopeJrArm(Robot):
    """HopeJR arm: seven Feetech motors on a single bus, plus optional cameras.

    Observations and actions are flat dicts keyed ``"<motor>.pos"``; observations also carry
    one entry per configured camera.
    """

    config_class = HopeJrArmConfig
    name = "hope_jr_arm"

    def __init__(self, config: HopeJrArmConfig):
        super().__init__(config)
        self.config = config
        self.bus = FeetechMotorsBus(
            port=self.config.port,
            motors={
                "shoulder_pitch": Motor(1, "sm8512bl", MotorNormMode.RANGE_M100_100),
                "shoulder_yaw": Motor(2, "sts3250", MotorNormMode.RANGE_M100_100),
                "shoulder_roll": Motor(3, "sts3250", MotorNormMode.RANGE_M100_100),
                "elbow_flex": Motor(4, "sts3250", MotorNormMode.RANGE_M100_100),
                "wrist_roll": Motor(5, "sts3250", MotorNormMode.RANGE_M100_100),
                "wrist_yaw": Motor(6, "sts3250", MotorNormMode.RANGE_M100_100),
                "wrist_pitch": Motor(7, "sts3250", MotorNormMode.RANGE_M100_100),
            },
            calibration=self.calibration,
        )
        self.cameras = make_cameras_from_configs(config.cameras)

        # HACK: shoulder_pitch is read on its own instead of through the group `sync_read`
        # used for the remaining motors (see `_read_present_positions`).
        # NOTE(review): presumably because it is the only "sm8512bl" motor on the bus — confirm.
        self.shoulder_pitch = "shoulder_pitch"
        self.other_motors = [m for m in self.bus.motors if m != "shoulder_pitch"]

    @property
    def _motors_ft(self) -> dict[str, type]:
        """One ``"<motor>.pos"`` float feature per motor on the bus."""
        return {f"{motor}.pos": float for motor in self.bus.motors}

    @property
    def _cameras_ft(self) -> dict[str, tuple]:
        """``(height, width, 3)`` tuple per configured camera, keyed by camera name."""
        return {
            cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
        }

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        """Motor position features followed by camera features."""
        return {**self._motors_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:
        """Actions are one ``"<motor>.pos"`` target per motor."""
        return self._motors_ft

    @property
    def is_connected(self) -> bool:
        """True only when the bus and every camera report being connected."""
        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())

    def connect(self, calibrate: bool = True) -> None:
        """
        We assume that at connection time, arm is in a rest position,
        and torque can be safely disabled to run calibration.

        Args:
            calibrate: run `calibrate()` when the bus reports it is not calibrated.

        Raises:
            DeviceAlreadyConnectedError: if the arm is already connected.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")

        self.bus.connect(handshake=False)
        if not self.is_calibrated and calibrate:
            self.calibrate()

        # Connect the cameras
        for cam in self.cameras.values():
            cam.connect()

        self.configure()
        logger.info(f"{self} connected.")

    @property
    def is_calibrated(self) -> bool:
        """Calibration status as reported by the motor bus."""
        return self.bus.is_calibrated

    def calibrate(self, limb_name: str | None = None) -> None:
        """Run the range-finder GUI over the motor groups and save the resulting calibration.

        Args:
            limb_name: currently unused; kept so existing callers that pass it keep working.
        """
        groups = {
            "all": list(self.bus.motors.keys()),
            "shoulder": ["shoulder_pitch", "shoulder_yaw", "shoulder_roll"],
            "elbow": ["elbow_flex"],
            "wrist": ["wrist_roll", "wrist_yaw", "wrist_pitch"],
        }

        self.calibration = RangeFinderGUI(self.bus, groups).run()
        self._save_calibration()
        print("Calibration saved to", self.calibration_fpath)

    def configure(self) -> None:
        """Configure the motors (acceleration settings of 30) while torque is disabled."""
        with self.bus.torque_disabled():
            self.bus.configure_motors(maximum_acceleration=30, acceleration=30)

    def setup_motors(self) -> None:
        """Interactively set up each motor, iterating over the bus motors in reverse order.

        For every motor the user is prompted to connect that motor alone to the controller
        board, then `bus.setup_motor` is called and the resulting id is printed.
        """
        for motor in reversed(self.bus.motors):
            input(f"Connect the controller board to the '{motor}' motor only and press enter.")
            self.bus.setup_motor(motor)
            print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")

    def _read_present_positions(self) -> dict[str, Any]:
        """Read "Present_Position" for every motor, keyed by motor name.

        The non-shoulder-pitch motors are read with one `sync_read`; shoulder_pitch is read
        with its own `read` call (see the HACK note in `__init__`).
        """
        positions = self.bus.sync_read("Present_Position", self.other_motors)
        positions[self.shoulder_pitch] = self.bus.read("Present_Position", self.shoulder_pitch)
        return positions

    def get_observation(self) -> dict[str, Any]:
        """Return current motor positions (``"<motor>.pos"``) and one frame per camera.

        Raises:
            DeviceNotConnectedError: if the arm is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Read arm position
        start = time.perf_counter()
        obs_dict = {f"{motor}.pos": val for motor, val in self._read_present_positions().items()}
        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} read state: {dt_ms:.1f}ms")

        # Capture images from cameras
        for cam_key, cam in self.cameras.items():
            start = time.perf_counter()
            obs_dict[cam_key] = cam.async_read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        return obs_dict

    def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
        """Write goal positions taken from the ``"<motor>.pos"`` entries of `action`.

        Keys that do not end in ``".pos"`` are ignored.

        Returns:
            The goal positions that were written, keyed ``"<motor>.pos"``; these may differ
            from `action` when `max_relative_target` is set.

        Raises:
            DeviceNotConnectedError: if the arm is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}

        # Cap goal position when too far away from present position.
        # /!\ Slower fps expected due to reading from the follower.
        if self.config.max_relative_target is not None:
            # Use the same split read as `get_observation` so shoulder_pitch is never part of
            # the group sync_read.
            present_pos = self._read_present_positions()
            goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
            goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)

        self.bus.sync_write("Goal_Position", goal_pos)
        return {f"{motor}.pos": val for motor, val in goal_pos.items()}

    def disconnect(self):
        """Disconnect the bus and then every camera.

        Raises:
            DeviceNotConnectedError: if the arm is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        self.bus.disconnect(self.config.disable_torque_on_disconnect)
        for cam in self.cameras.values():
            cam.disconnect()

        logger.info(f"{self} disconnected.")
|
||||
200
src/lerobot/robots/hope_jr/hope_jr_hand.py
Normal file
200
src/lerobot/robots/hope_jr/hope_jr_hand.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_hope_jr import HopeJrHandConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors of a right hand that HopeJrHand treats as inverted: HopeJrHand.calibrate() sets
# `drive_mode = 1` on the calibration entry of each motor listed here.
RIGHT_HAND_INVERSIONS = [
|
||||
"thumb_mcp",
|
||||
"thumb_dip",
|
||||
"index_ulnar_flexor",
|
||||
"middle_ulnar_flexor",
|
||||
"ring_ulnar_flexor",
|
||||
"ring_pip_dip",
|
||||
"pinky_ulnar_flexor",
|
||||
"pinky_pip_dip",
|
||||
]
|
||||
|
||||
# Motors of a left hand that HopeJrHand treats as inverted: HopeJrHand.calibrate() sets
# `drive_mode = 1` on the calibration entry of each motor listed here.
# NOTE(review): "pinky_pip_dip" is commented out below, unlike in the right-hand list —
# confirm this is intentional.
LEFT_HAND_INVERSIONS = [
|
||||
"thumb_cmc",
|
||||
"thumb_mcp",
|
||||
"thumb_dip",
|
||||
"index_radial_flexor",
|
||||
"index_pip_dip",
|
||||
"middle_radial_flexor",
|
||||
"middle_pip_dip",
|
||||
"ring_radial_flexor",
|
||||
"ring_pip_dip",
|
||||
"pinky_radial_flexor",
|
||||
# "pinky_pip_dip",
|
||||
]
|
||||
|
||||
|
||||
class HopeJrHand(Robot):
    """HopeJR hand: sixteen "scs0009" Feetech motors on one bus, plus optional cameras.

    Observations and actions are flat dicts keyed ``"<motor>.pos"``; observations also carry
    one entry per configured camera.
    """

    config_class = HopeJrHandConfig
    name = "hope_jr_hand"

    def __init__(self, config: HopeJrHandConfig):
        super().__init__(config)
        self.config = config

        # Motor ids are assigned from this declaration order, starting at 1.
        motor_names = [
            # Thumb
            "thumb_cmc",
            "thumb_mcp",
            "thumb_pip",
            "thumb_dip",
            # Index
            "index_radial_flexor",
            "index_ulnar_flexor",
            "index_pip_dip",
            # Middle
            "middle_radial_flexor",
            "middle_ulnar_flexor",
            "middle_pip_dip",
            # Ring
            "ring_radial_flexor",
            "ring_ulnar_flexor",
            "ring_pip_dip",
            # Pinky
            "pinky_radial_flexor",
            "pinky_ulnar_flexor",
            "pinky_pip_dip",
        ]
        hand_motors = {
            motor_name: Motor(motor_id, "scs0009", MotorNormMode.RANGE_0_100)
            for motor_id, motor_name in enumerate(motor_names, start=1)
        }
        self.bus = FeetechMotorsBus(
            port=self.config.port,
            motors=hand_motors,
            calibration=self.calibration,
            protocol_version=1,
        )
        self.cameras = make_cameras_from_configs(config.cameras)

        # Which motors get `drive_mode = 1` during calibration depends on the hand side.
        if config.side == "right":
            self.inverted_motors = RIGHT_HAND_INVERSIONS
        else:
            self.inverted_motors = LEFT_HAND_INVERSIONS

    @property
    def _motors_ft(self) -> dict[str, type]:
        """One ``"<motor>.pos"`` float feature per motor on the bus."""
        features = {}
        for motor in self.bus.motors:
            features[f"{motor}.pos"] = float
        return features

    @property
    def _cameras_ft(self) -> dict[str, tuple]:
        """``(height, width, 3)`` tuple per configured camera, keyed by camera name."""
        shapes = {}
        for cam in self.cameras:
            cam_cfg = self.config.cameras[cam]
            shapes[cam] = (cam_cfg.height, cam_cfg.width, 3)
        return shapes

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        """Motor position features followed by camera features."""
        features = dict(self._motors_ft)
        features.update(self._cameras_ft)
        return features

    @cached_property
    def action_features(self) -> dict[str, type]:
        """Actions are one ``"<motor>.pos"`` target per motor."""
        return self._motors_ft

    @property
    def is_connected(self) -> bool:
        """True only when the bus and every camera report being connected."""
        return self.bus.is_connected and all(camera.is_connected for camera in self.cameras.values())

    def connect(self, calibrate: bool = True) -> None:
        """Connect the bus, calibrate if requested and needed, connect cameras, then configure.

        Raises:
            DeviceAlreadyConnectedError: if the hand is already connected.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")

        self.bus.connect()
        needs_calibration = not self.is_calibrated
        if needs_calibration and calibrate:
            self.calibrate()

        # Connect the cameras
        for camera in self.cameras.values():
            camera.connect()

        self.configure()
        logger.info(f"{self} connected.")

    @property
    def is_calibrated(self) -> bool:
        """Calibration status as reported by the motor bus."""
        return self.bus.is_calibrated

    def calibrate(self) -> None:
        """Run the range-finder GUI with motors grouped by finger, then save the calibration.

        Motors listed in `self.inverted_motors` get `drive_mode = 1` before saving.
        """
        fingers = {
            finger: [motor for motor in self.bus.motors if motor.startswith(finger)]
            for finger in ["thumb", "index", "middle", "ring", "pinky"]
        }

        self.calibration = RangeFinderGUI(self.bus, fingers).run()
        for inverted in self.inverted_motors:
            self.calibration[inverted].drive_mode = 1
        self._save_calibration()
        print("Calibration saved to", self.calibration_fpath)

    def configure(self) -> None:
        """Configure the motors while torque is disabled."""
        with self.bus.torque_disabled():
            self.bus.configure_motors()

    def setup_motors(self) -> None:
        """Interactively set up each motor, one at a time, in bus declaration order."""
        for motor in self.bus.motors:
            input(f"Connect the controller board to the '{motor}' motor only and press enter.")
            self.bus.setup_motor(motor)
            print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")

    def get_observation(self) -> dict[str, Any]:
        """Return current motor positions (``"<motor>.pos"``) and one frame per camera.

        Raises:
            DeviceNotConnectedError: if the hand is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Read hand position, one motor at a time
        start = time.perf_counter()
        obs_dict = {f"{motor}.pos": self.bus.read("Present_Position", motor) for motor in self.bus.motors}
        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} read state: {dt_ms:.1f}ms")

        # Capture images from cameras
        for cam_key, cam in self.cameras.items():
            start = time.perf_counter()
            obs_dict[cam_key] = cam.async_read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        return obs_dict

    def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
        """Write goal positions taken from the ``"<motor>.pos"`` entries of `action`.

        Keys that do not end in ``".pos"`` are ignored for the write.

        Returns:
            The `action` dict that was passed in, unchanged.

        Raises:
            DeviceNotConnectedError: if the hand is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        goal_pos = {}
        for key, val in action.items():
            if key.endswith(".pos"):
                goal_pos[key.removesuffix(".pos")] = val
        self.bus.sync_write("Goal_Position", goal_pos)
        return action

    def disconnect(self):
        """Disconnect the bus and then every camera.

        Raises:
            DeviceNotConnectedError: if the hand is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        self.bus.disconnect(self.config.disable_torque_on_disconnect)
        for camera in self.cameras.values():
            camera.disconnect()

        logger.info(f"{self} disconnected.")
|
||||
2
src/lerobot/robots/so101_follower_torque/__init__.py
Normal file
2
src/lerobot/robots/so101_follower_torque/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .config_so101_follower_t import SO101FollowerTConfig
|
||||
from .so101_follower_t import SO101FollowerT
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so101_follower_t")
@dataclass
class SO101FollowerTConfig(RobotConfig):
    port: str  # Port to connect to the arm

    disable_torque_on_disconnect: bool = True

    # Safety limit on the magnitude of the relative positional target vector; None leaves it
    # unset. A positive scalar applies the same value to all motors.
    # NOTE(review): earlier documentation also mentioned a list with one entry per motor, but
    # the annotation only admits a single int — confirm which is intended.
    max_relative_target: int | None = None

    # Camera configurations, keyed by camera name.
    cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
553
src/lerobot/robots/so101_follower_torque/so101_follower_t.py
Normal file
553
src/lerobot/robots/so101_follower_torque/so101_follower_t.py
Normal file
@@ -0,0 +1,553 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
from scipy.signal import butter, lfilter
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_so101_follower_t import SO101FollowerTConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO101FollowerT(Robot):
    """
    SO-101 Arm with HLS3625 motors with current control.
    """

    config_class = SO101FollowerTConfig
    name = "so101_follower_t"

    # Current represented by one register count [A]: 6.5 mA per LSB.
    # http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
    _CURRENT_STEP_A: float = 6.5e-3
    # Torque constant Kt [N·m/A]. https://www.feetechrc.com/811177.html
    _KT_NM_PER_AMP: float = 0.814
    # Current limit [A] used to clamp both measured and commanded currents.
    _MAX_CURRENT_A: float = 4.0

    # Position gains, per joint.
    _KP_GAINS = dict(
        shoulder_pan=5.0,
        shoulder_lift=7.0,
        elbow_flex=7.0,
        wrist_flex=5.0,
        wrist_roll=5.0,
        gripper=5.0,
    )

    # Velocity gains, per joint.
    _KD_GAINS = dict(
        shoulder_pan=0.4,
        shoulder_lift=0.6,
        elbow_flex=0.6,
        wrist_flex=0.4,
        wrist_roll=0.4,
        gripper=0.4,
    )

    # Force gains, per joint.
    _KF_GAINS = dict(
        shoulder_pan=0.05,
        shoulder_lift=0.05,
        elbow_flex=0.05,
        wrist_flex=0.05,
        wrist_roll=0.05,
        gripper=0.05,
    )

    # Viscous friction coefficients, per joint (multiplied by joint velocity).
    _FRICTION_VISCOUS = dict(
        shoulder_pan=0.05,
        shoulder_lift=0.08,
        elbow_flex=0.05,
        wrist_flex=0.05,
        wrist_roll=0.05,
        gripper=0.05,
    )

    # Coulomb/static friction magnitudes, per joint (multiplied by the sign of the velocity).
    _FRICTION_COULOMB = dict(
        shoulder_pan=0.15,
        shoulder_lift=0.25,
        elbow_flex=0.25,
        wrist_flex=0.20,
        wrist_roll=0.20,
        gripper=0.20,
    )
|
||||
|
||||
def __init__(self, config: SO101FollowerTConfig):
    """Build the motor bus, cameras, dynamics model and velocity-filter state.

    Nothing is connected here; call `connect()` afterwards.

    Args:
        config: robot configuration (port, cameras, disconnect behaviour).
    """
    super().__init__(config)
    self.config = config

    if self.calibration_fpath.is_file() and not self.calibration:
        self._load_calibration()

    # Six HLS3625 motors with bus ids 1..6, all normalised in degrees.
    joint_names = ("shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper")
    self.bus = FeetechMotorsBus(
        port=self.config.port,
        motors={
            joint: Motor(motor_id, "hls3625", MotorNormMode.DEGREES)
            for motor_id, joint in enumerate(joint_names, start=1)
        },
        calibration=self.calibration,
    )
    self.cameras = make_cameras_from_configs(config.cameras)

    # NOTE(review): both paths are relative to the current working directory, so the
    # robot can only be constructed from a directory that contains `urdf/`.
    self.pin_robot = pin.RobotWrapper.BuildFromURDF("urdf/so101_new_calib.urdf", "urdf")

    # Per-joint flag: True means the torque sign is inverted between the model and the motor.
    flip = {
        "shoulder_pan": True,
        "shoulder_lift": True,
        "elbow_flex": True,
        "wrist_flex": True,
        "wrist_roll": True,
        "gripper": True,
    }
    self.torque_sign = {n: (-1.0 if flip[n] else 1.0) for n in self.bus.motors}

    # State of the finite-difference velocity/acceleration estimator.
    self._prev_pos_rad: dict[str, float] | None = None
    self._prev_vel_rad: dict[str, float] | None = None
    self._prev_t: float | None = None

    # Butterworth low-pass filter parameters
    self._cutoff_freq = 10.0  # Hz, cutoff frequency for the filter
    self._filter_order = 2  # Filter order
    self._sampling_freq = 100.0  # Hz, control-loop frequency the coefficients are designed for

    nyquist_freq = self._sampling_freq / 2
    normalized_cutoff = self._cutoff_freq / nyquist_freq
    self._b, self._a = butter(self._filter_order, normalized_cutoff, btype="low")

    # History buffers (at most the 20 most recent samples each)
    self._pos_history = {m: collections.deque(maxlen=20) for m in self.bus.motors}
    self._vel_raw_history = {m: collections.deque(maxlen=20) for m in self.bus.motors}
    self._time_history = collections.deque(maxlen=20)

    # Most recent observation, reused by send_action for feedforward compensation.
    self._last_observation = None
    # Last compensated torque command [Nm]; declared here so it can be read before the
    # first call to send_action without raising AttributeError.
    self._last_cmd_nm: dict[str, float] | None = None
|
||||
|
||||
@property
def _motors_ft(self) -> dict[str, type]:
    """Motor feature spec: "<motor>.pos", "<motor>.vel" and "<motor>.effort", each typed `float`."""
    return {
        f"{motor}.{suffix}": float
        for motor in self.bus.motors
        for suffix in ("pos", "vel", "effort")
    }
|
||||
|
||||
@property
def _cameras_ft(self) -> dict[str, tuple]:
    """Camera feature spec: camera name -> (height, width, 3), taken from the camera configs."""
    shapes = {}
    for cam_name in self.cameras:
        cam_cfg = self.config.cameras[cam_name]
        shapes[cam_name] = (cam_cfg.height, cam_cfg.width, 3)
    return shapes
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
def action_features(self) -> dict[str, type]:
    """Action spec: "<motor>.pos", "<motor>.vel" and "<motor>.effort" per motor, each typed `float`.

    Cached after first access.
    """
    return {
        f"{motor}.{suffix}": float
        for motor in self.bus.motors
        for suffix in ("pos", "vel", "effort")
    }
|
||||
|
||||
@property
def is_connected(self) -> bool:
    """True only when the motor bus and every camera report being connected."""
    bus_up = self.bus.is_connected
    return bus_up and all(camera.is_connected for camera in self.cameras.values())
|
||||
|
||||
@property
def kp_gains(self) -> dict[str, float]:
    """Position control gains [Nm/rad] for bilateral teleoperation (a fresh copy per access)."""
    return dict(self._KP_GAINS)
|
||||
|
||||
@property
def kd_gains(self) -> dict[str, float]:
    """Velocity control gains [Nm⋅s/rad] for bilateral teleoperation (a fresh copy per access)."""
    return dict(self._KD_GAINS)
|
||||
|
||||
@property
def kf_gains(self) -> dict[str, float]:
    """Force control gains for bilateral teleoperation (a fresh copy per access)."""
    return dict(self._KF_GAINS)
|
||||
|
||||
@property
def friction_viscous(self) -> dict[str, float]:
    """Viscous friction coefficients [Nm⋅s/rad] for friction compensation (a fresh copy per access)."""
    return dict(self._FRICTION_VISCOUS)
|
||||
|
||||
@property
def friction_coulomb(self) -> dict[str, float]:
    """Coulomb friction coefficients [Nm] for friction compensation (a fresh copy per access)."""
    return dict(self._FRICTION_COULOMB)
|
||||
|
||||
def set_butterworth_params(self, cutoff_freq: float = 10.0, order: int = 2) -> None:
    """Redesign the Butterworth low-pass filter used for velocity estimation.

    The coefficients are recomputed for the stored sampling frequency and all
    history buffers are emptied so old samples are not run through the new filter.

    Args:
        cutoff_freq: Cutoff frequency in Hz (default: 10 Hz). Must be positive and
            below the Nyquist frequency (half the sampling frequency).
        order: Filter order (default: 2). Must be at least 1.

    Raises:
        ValueError: if any argument is outside the ranges above.
    """
    nyquist = self._sampling_freq / 2

    if cutoff_freq <= 0:
        raise ValueError("Cutoff frequency must be positive")
    if cutoff_freq >= nyquist:
        raise ValueError(
            f"Cutoff frequency must be less than Nyquist frequency ({self._sampling_freq / 2} Hz)"
        )
    if order < 1:
        raise ValueError("Filter order must be at least 1")

    self._cutoff_freq = cutoff_freq
    self._filter_order = order

    # scipy's digital `butter` expects the cutoff as a fraction of Nyquist.
    self._b, self._a = butter(self._filter_order, self._cutoff_freq / nyquist, btype="low")

    # Clear buffers
    for motor in self.bus.motors:
        for history in (self._pos_history, self._vel_raw_history):
            history[motor].clear()
    self._time_history.clear()

    logger.info(f"Butterworth filter updated: cutoff_freq={cutoff_freq} Hz, order={order}")
|
||||
|
||||
def _current_to_torque_nm(self, raw: dict[str, Any]) -> dict[str, float]:
    """Convert "Present_Current" register counts to joint torque [Nm].

    Counts are first clamped to ±(_MAX_CURRENT_A / _CURRENT_STEP_A), then scaled by
    _CURRENT_STEP_A * _KT_NM_PER_AMP and multiplied by the joint's torque sign.

    Args:
        raw: register counts keyed by motor name.

    Returns:
        Torque in Nm keyed by motor name.
    """
    limit = int(round(self._MAX_CURRENT_A / self._CURRENT_STEP_A))
    nm_per_count = self._CURRENT_STEP_A * self._KT_NM_PER_AMP

    torques = {}
    for joint, counts in raw.items():
        clamped = max(-limit, min(limit, counts))
        torques[joint] = self.torque_sign[joint] * clamped * nm_per_count
    return torques
|
||||
|
||||
def _torque_nm_to_current(self, torque: dict[str, float]) -> dict[str, int]:
    """Convert joint torque [Nm] to signed current register counts.

    Each torque is multiplied by the joint's torque sign, divided by
    _CURRENT_STEP_A * _KT_NM_PER_AMP, clamped to ±(_MAX_CURRENT_A / _CURRENT_STEP_A)
    counts and rounded to an integer.

    Args:
        torque: torque in Nm keyed by motor name.

    Returns:
        Register counts keyed by motor name.

    Raises:
        ValueError: if a torque is NaN. Without this check the min/max clamp turns
            NaN into the positive limit, i.e. a full-current command.
    """
    counts_per_nm = 1.0 / (self._CURRENT_STEP_A * self._KT_NM_PER_AMP)
    limit = int(round(self._MAX_CURRENT_A / self._CURRENT_STEP_A))

    counts = {}
    for joint, tau in torque.items():
        if np.isnan(tau):
            raise ValueError(f"Torque command for '{joint}' is NaN")
        cnt = tau * self.torque_sign[joint] * counts_per_nm
        cnt = max(-limit, min(limit, cnt))  # infinities saturate at the limit
        counts[joint] = int(round(cnt))
    return counts
|
||||
|
||||
def _deg_to_rad(self, deg: dict[str, float | int]) -> dict[str, float]:
|
||||
"""Degrees to radians."""
|
||||
return {m: np.deg2rad(float(v)) for m, v in deg.items()}
|
||||
|
||||
def _gravity_from_q(self, q_rad: dict[str, float]) -> dict[str, float]:
    """
    Compute the gravity torque g(q) [N m] for every motor on the bus.

    Joint i of the model configuration vector is filled from the i-th motor in
    `self.bus.motors`, so the URDF joint order is assumed to match the bus order.
    """
    model = self.pin_robot.model
    joint_names = list(self.bus.motors)

    q = np.zeros(model.nq)
    for idx, joint in enumerate(joint_names):
        q[idx] = q_rad[joint]

    gravity = pin.computeGeneralizedGravity(model, self.pin_robot.data, q)

    return {joint: float(gravity[idx]) for idx, joint in enumerate(joint_names)}
|
||||
|
||||
def _inertia_from_q_dq(
    self, q_rad: dict[str, float], dq_rad: dict[str, float], ddq_rad: dict[str, float]
) -> dict[str, float]:
    """
    Compute inertia torques τ_inertia = M(q) * ddq from the URDF model.

    Joint i of the model vectors is filled from the i-th motor in `self.bus.motors`.
    The velocities are read into a vector as well, but only q and ddq enter the result.
    """
    model = self.pin_robot.model
    joint_names = list(self.bus.motors)

    q = np.zeros(model.nq)
    dq = np.zeros(model.nv)
    ddq = np.zeros(model.nv)
    for idx, joint in enumerate(joint_names):
        q[idx] = q_rad[joint]
        dq[idx] = dq_rad[joint]
        ddq[idx] = ddq_rad[joint]

    # Mass matrix M(q), then τ_inertia = M(q) * ddq
    tau_inertia = pin.crba(model, self.pin_robot.data, q) @ ddq

    return {joint: float(tau_inertia[idx]) for idx, joint in enumerate(joint_names)}
|
||||
|
||||
def _compute_model_based_disturbance(
    self,
    q_rad: dict[str, float],
    dq_rad: dict[str, float],
    ddq_rad: dict[str, float],
    tau_measured: dict[str, float],
) -> dict[str, float]:
    """
    Estimate disturbance torques by removing the modelled terms from the measurement:

        τ_disturbance = τ_measured - τ_gravity - τ_inertia - τ_friction

    where τ_friction = -(viscous * ω + coulomb * sign(ω)) and sign(ω) is taken as 0
    inside a ±0.01 dead band around zero velocity.

    Args:
        q_rad: joint positions keyed by motor name.
        dq_rad: joint velocities keyed by motor name.
        ddq_rad: joint accelerations keyed by motor name.
        tau_measured: measured joint torques [Nm] keyed by motor name.

    Returns:
        Disturbance torque [Nm] for every motor on the bus.
    """
    tau_gravity = self._gravity_from_q(q_rad)
    tau_inertia = self._inertia_from_q_dq(q_rad, dq_rad, ddq_rad)

    tau_disturbance = {}
    for motor_name in self.bus.motors:
        omega = dq_rad[motor_name]
        # Dead band keeps the Coulomb term from chattering around zero velocity.
        direction = 1.0 if omega > 0.01 else -1.0 if omega < -0.01 else 0.0

        # Friction model, negated to apply the torque sign correction.
        tau_friction = -(
            self._FRICTION_VISCOUS[motor_name] * omega + self._FRICTION_COULOMB[motor_name] * direction
        )

        tau_disturbance[motor_name] = (
            tau_measured[motor_name] - tau_gravity[motor_name] - tau_inertia[motor_name] - tau_friction
        )

    return tau_disturbance
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
    """
    Connect the motor bus and cameras, then configure the motors.

    We assume that at connection time, arm is in a rest position,
    and torque can be safely disabled to run calibration.

    Args:
        calibrate: run `calibrate()` after the bus connects if the robot is not calibrated.

    Raises:
        DeviceAlreadyConnectedError: if the robot is already connected.
    """
    if self.is_connected:
        raise DeviceAlreadyConnectedError(f"{self} already connected")

    # Pick up a calibration file that exists on disk but has not been loaded yet,
    # and hand the loaded values to the bus.
    have_unloaded_file = self.calibration_fpath.is_file() and not self.calibration
    if have_unloaded_file:
        self._load_calibration()
        self.bus.calibration = self.calibration

    self.bus.connect()
    if not self.is_calibrated and calibrate:
        self.calibrate()

    for camera in self.cameras.values():
        camera.connect()

    self.configure()
    logger.info(f"{self} connected.")
|
||||
|
||||
@property
def is_calibrated(self) -> bool:
    """True when the calibration file exists on disk and a non-empty calibration is loaded."""
    if not self.calibration_fpath.is_file():
        return False
    return bool(self.calibration)
|
||||
|
||||
def calibrate(self) -> None:
    """Interactively calibrate the arm and save the result to the calibration file.

    Torque is disabled and every motor is put in operating mode 2 (current mode).
    The user is then asked to centre the arm (homing offsets) and to sweep each
    joint through its range (min/max). The result is stored on the robot and the
    bus, then written to disk.
    """
    logger.info(f"\nRunning calibration of {self}")
    self.bus.disable_torque()
    for joint in self.bus.motors:
        self.bus.write("Operating_Mode", joint, 2, num_retry=2)  # Set to current mode

    input(f"Move {self} to the middle of its range of motion and press ENTER....")
    offsets = self.bus.set_half_turn_homings()

    print(
        "Move all joints sequentially through their entire ranges "
        "of motion.\nRecording positions. Press ENTER to stop..."
    )
    lows, highs = self.bus.record_ranges_of_motion()

    self.calibration = {}
    for joint, motor in self.bus.motors.items():
        entry = MotorCalibration(
            id=motor.id,
            drive_mode=0,
            homing_offset=int(offsets[joint]),
            range_min=int(lows[joint]),
            range_max=int(highs[joint]),
        )
        self.calibration[joint] = entry

    # Keep the bus in sync with the freshly recorded values, then persist them.
    self.bus.calibration = self.calibration
    self._save_calibration()
    print("Calibration saved to", self.calibration_fpath)
|
||||
|
||||
def configure(self) -> None:
    """Disable torque, then put every motor in operating mode 2 and write 0 to "Present_Current".

    NOTE(review): "Present_Current" reads like a measurement register; confirm that
    writing 0 to it (rather than to the torque target) is what is intended.
    """
    bus = self.bus
    bus.disable_torque()  # here was issue at startup previously
    for joint in bus.motors:
        bus.write("Operating_Mode", joint, 2, num_retry=2)  # Set to current mode
        bus.write("Present_Current", joint, 0, normalize=False, num_retry=5)
|
||||
|
||||
def setup_motors(self) -> None:
    """Interactively set up each motor, one at a time, in reverse bus order.

    For every motor the user is asked to connect only that motor to the controller
    board; the bus then sets it up and the resulting id is printed.
    """
    for joint in reversed(self.bus.motors):
        input(f"Connect the controller board to the '{joint}' motor only and press enter.")
        self.bus.setup_motor(joint)
        motor_id = self.bus.motors[joint].id
        print(f"'{joint}' motor id set to {motor_id}")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
    """Read the arm state and camera frames.

    Positions are read from the bus in degrees and converted to radians. Velocity is
    a finite difference of position, low-pass filtered once at least 10 raw samples
    are buffered. Acceleration is a finite difference of the resulting velocity.
    Measured current is converted to torque and the modelled terms are removed to
    obtain the reported effort.

    Raises:
        DeviceNotConnectedError: if the robot is not connected.

    Returns:
        "<motor>.pos", "<motor>.vel", "<motor>.acc" and "<motor>.effort" for every
        motor, followed by one entry per camera.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    t_now = time.perf_counter()

    # Position
    pos_rad = self._deg_to_rad(self.bus.sync_read("Present_Position", num_retry=5))

    # Store position and time history
    for joint, angle in pos_rad.items():
        self._pos_history[joint].append(angle)
    self._time_history.append(t_now)

    # Time since the previous observation, floored at 0.1 ms to avoid division by zero.
    dt = None if self._prev_t is None else max(t_now - self._prev_t, 1e-4)

    # Raw velocity (zero on the first observation)
    if self._prev_pos_rad is None or dt is None:
        vel_rad_raw = dict.fromkeys(pos_rad, 0.0)
    else:
        vel_rad_raw = {joint: (pos_rad[joint] - self._prev_pos_rad[joint]) / dt for joint in pos_rad}

    for joint, omega in vel_rad_raw.items():
        self._vel_raw_history[joint].append(omega)

    # Butterworth low-pass over the buffered raw velocities; only the newest output is kept.
    vel_rad = {}
    for joint in pos_rad:
        window = self._vel_raw_history[joint]
        if len(window) >= 10:
            vel_rad[joint] = lfilter(self._b, self._a, np.array(list(window)))[-1]
        else:
            vel_rad[joint] = vel_rad_raw[joint]

    # Acceleration from filtered velocity (zero on the first observation)
    if self._prev_vel_rad is None or dt is None:
        acc_rad = dict.fromkeys(pos_rad, 0.0)
    else:
        acc_rad = {joint: (vel_rad[joint] - self._prev_vel_rad[joint]) / dt for joint in vel_rad}

    self._prev_pos_rad = pos_rad.copy()
    self._prev_vel_rad = vel_rad.copy()
    self._prev_t = t_now

    # Measured torque (Nm)
    tau_meas = self._current_to_torque_nm(
        self.bus.sync_read("Present_Current", normalize=False, num_retry=5)
    )

    # Reaction torques using the model-based approach
    tau_reaction = self._compute_model_based_disturbance(pos_rad, vel_rad, acc_rad, tau_meas)

    obs_dict = {}
    for suffix, values in (("pos", pos_rad), ("vel", vel_rad), ("acc", acc_rad), ("effort", tau_reaction)):
        for joint in self.bus.motors:
            obs_dict[f"{joint}.{suffix}"] = values[joint]

    # Capture images from cameras
    for cam_key, cam in self.cameras.items():
        start = time.perf_counter()
        obs_dict[cam_key] = cam.read()
        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

    # Store observation for feedforward compensation
    self._last_observation = obs_dict.copy()

    return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
    """Command joint torques, adding model-based feedforward compensation.

    Only keys ending in ".effort" are used. If an observation has been taken, the
    gravity, inertia and friction torques evaluated at that observation are added to
    each commanded torque. The result is converted to register counts and written
    to "Target_Torque".

    Raises:
        DeviceNotConnectedError: if robot is not connected.

    Returns:
        The `action` dict that was passed in, unchanged. When it contains no
        ".effort" key nothing is written to the bus.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    # Extract torque commands
    tau_cmd_nm = {k.removesuffix(".effort"): float(v) for k, v in action.items() if k.endswith(".effort")}
    if not tau_cmd_nm:
        return action

    # Add feedforward compensation if we have a last observation
    if self._last_observation is not None:
        pos_rad = {m: self._last_observation[f"{m}.pos"] for m in self.bus.motors}
        vel_rad = {m: self._last_observation[f"{m}.vel"] for m in self.bus.motors}
        acc_rad = {m: self._last_observation[f"{m}.acc"] for m in self.bus.motors}

        tau_gravity = self._gravity_from_q(pos_rad)
        tau_inertia = self._inertia_from_q_dq(pos_rad, vel_rad, acc_rad)

        for motor in tau_cmd_nm:
            # Gravity, then inertia compensation
            tau_cmd_nm[motor] += tau_gravity[motor]
            tau_cmd_nm[motor] += tau_inertia[motor]

            # Friction compensation: viscous + Coulomb with a ±0.01 dead band
            omega = vel_rad[motor]
            tau_friction = self._FRICTION_VISCOUS[motor] * omega + self._FRICTION_COULOMB[motor] * (
                1.0 if omega > 0.01 else -1.0 if omega < -0.01 else 0.0
            )
            tau_friction = -tau_friction  # Apply torque sign correction
            tau_cmd_nm[motor] += tau_friction

    # Sign flip, scaling and clamping live in one place instead of being repeated here.
    counts = self._torque_nm_to_current(tau_cmd_nm)

    self.bus.sync_write("Target_Torque", counts, normalize=False, num_retry=2)
    self._last_cmd_nm = tau_cmd_nm
    return action
|
||||
|
||||
def disconnect(self):
    """Close the motor bus, then every camera.

    The bus receives `config.disable_torque_on_disconnect` as its only argument.

    Raises:
        DeviceNotConnectedError: if the robot is not connected.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    torque_off = self.config.disable_torque_on_disconnect
    self.bus.disconnect(torque_off)

    for camera in self.cameras.values():
        camera.disconnect()

    logger.info(f"{self} disconnected.")
|
||||
@@ -37,6 +37,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .so101_follower import SO101Follower
|
||||
|
||||
return SO101Follower(config)
|
||||
elif config.type == "so101_follower_t":
|
||||
from .so101_follower_torque import SO101FollowerT
|
||||
|
||||
return SO101FollowerT(config)
|
||||
elif config.type == "lekiwi":
|
||||
from .lekiwi import LeKiwi
|
||||
|
||||
@@ -49,6 +53,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .viperx import ViperX
|
||||
|
||||
return ViperX(config)
|
||||
elif config.type == "hope_jr_hand":
|
||||
from .hope_jr import HopeJrHand
|
||||
|
||||
return HopeJrHand(config)
|
||||
elif config.type == "hope_jr_arm":
|
||||
from .hope_jr import HopeJrArm
|
||||
|
||||
return HopeJrArm(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
|
||||
@@ -317,7 +317,7 @@ def act_with_policy(
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
@@ -642,9 +642,29 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
state_dicts = bytes_to_state_dict(bytes_state_dict)
|
||||
|
||||
# TODO: check encoder parameter synchronization possible issues:
|
||||
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
|
||||
# instead of the updated encoder params from critic (which is optimized separately)
|
||||
# 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params
|
||||
# 3. Need to handle encoder params correctly for both actor and discrete_critic
|
||||
# Potential fixes:
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
#################################################
|
||||
|
||||
@@ -1109,8 +1109,18 @@ def check_nan_in_transition(
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
|
||||
state_bytes = state_to_bytes(state_dict)
|
||||
|
||||
# Create a dictionary to hold all the state dicts
|
||||
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
|
||||
|
||||
# Add discrete critic if it exists
|
||||
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
logging.debug("[LEARNER] Including discrete critic in state dict push")
|
||||
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
|
||||
197
src/lerobot/scripts/server/configs.py
Normal file
197
src/lerobot/scripts/server/configs.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
from lerobot.scripts.server.constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
)
|
||||
|
||||
# Aggregate function registry for CLI usage
|
||||
def _blend(old_weight: float, new_weight: float):
    """Return f(old, new) = old_weight * old + new_weight * new."""

    def aggregate(old, new):
        return old_weight * old + new_weight * new

    return aggregate


# Aggregate function registry for CLI usage: name -> f(old, new)
AGGREGATE_FUNCTIONS = {
    "weighted_average": _blend(0.3, 0.7),
    "latest_only": lambda old, new: new,
    "average": _blend(0.5, 0.5),
    "conservative": _blend(0.7, 0.3),
}


def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Look up an aggregate function in `AGGREGATE_FUNCTIONS` by name.

    Raises:
        ValueError: if `name` is not a registered aggregate function; the message
            lists the available names.
    """
    aggregate_fn = AGGREGATE_FUNCTIONS.get(name)
    if aggregate_fn is None:
        available = list(AGGREGATE_FUNCTIONS.keys())
        raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
    return aggregate_fn
|
||||
|
||||
|
||||
@dataclass
class PolicyServerConfig:
    """Configuration for PolicyServer.

    This class defines all configurable parameters for the PolicyServer,
    including networking settings and action chunking specifications.
    """

    # Networking configuration
    host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
    port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})

    # Timing configuration
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
    inference_latency: float = field(
        default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
    )

    obs_queue_timeout: float = field(
        default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
    )

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: if the port is outside 1..65535, fps is not positive, or a
                latency/timeout is negative.
        """
        if self.port < 1 or self.port > 65535:
            raise ValueError(f"Port must be between 1 and 65535, got {self.port}")

        # Check fps itself: environment_dt is 1 / fps, so validating the derived value
        # would raise ZeroDivisionError for fps == 0 instead of a ValueError.
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")

        if self.inference_latency < 0:
            raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")

        if self.obs_queue_timeout < 0:
            raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
        """Create a PolicyServerConfig from a dictionary.

        The derived "environment_dt" entry emitted by `to_dict` is ignored so that
        `from_dict(cfg.to_dict())` round-trips; any other unknown key still raises
        TypeError.
        """
        init_kwargs = {key: value for key, value in config_dict.items() if key != "environment_dt"}
        return cls(**init_kwargs)

    @property
    def environment_dt(self) -> float:
        """Environment time step, in seconds"""
        return 1 / self.fps

    def to_dict(self) -> dict:
        """Convert the configuration to a dictionary.

        Contains every field plus the derived "environment_dt".
        """
        return {
            "host": self.host,
            "port": self.port,
            "fps": self.fps,
            "environment_dt": self.environment_dt,
            "inference_latency": self.inference_latency,
            "obs_queue_timeout": self.obs_queue_timeout,
        }
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotClientConfig:
|
||||
"""Configuration for RobotClient.
|
||||
|
||||
This class defines all configurable parameters for the RobotClient,
|
||||
including network connection, policy settings, and control behavior.
|
||||
"""
|
||||
|
||||
# Policy configuration
|
||||
policy_type: str = field(metadata={"help": "Type of policy to use"})
|
||||
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
|
||||
|
||||
# Robot configuration (for CLI usage - robot instance will be created from this)
|
||||
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
|
||||
|
||||
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
|
||||
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
|
||||
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
|
||||
|
||||
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
|
||||
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
|
||||
|
||||
# Network configuration
|
||||
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
|
||||
# Aggregate function configuration (CLI-compatible)
|
||||
aggregate_fn_name: str = field(
|
||||
default="weighted_average",
|
||||
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
|
||||
)
|
||||
|
||||
# Debug configuration
|
||||
debug_visualize_queue_size: bool = field(
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
# Verification configuration
|
||||
verify_robot_cameras: bool = field(
|
||||
default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
|
||||
)
|
||||
|
||||
@property
def environment_dt(self) -> float:
    """Duration of one environment step in seconds, i.e. the reciprocal of `fps`."""
    step_seconds = 1 / self.fps
    return step_seconds
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.server_address:
|
||||
raise ValueError("server_address cannot be empty")
|
||||
|
||||
if not self.policy_type:
|
||||
raise ValueError("policy_type cannot be empty")
|
||||
|
||||
if not self.pretrained_name_or_path:
|
||||
raise ValueError("pretrained_name_or_path cannot be empty")
|
||||
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
|
||||
if self.actions_per_chunk <= 0:
|
||||
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
|
||||
|
||||
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
|
||||
|
||||
@classmethod
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
    """Build a config by passing every entry of `config_dict` to the constructor as a keyword argument."""
    instance = cls(**config_dict)
    return instance
|
||||
|
||||
def to_dict(self) -> dict:
    """Export the configuration as a plain dictionary.

    Only the fields listed below are exported; `robot` and `verify_robot_cameras`
    are not part of the returned dictionary.
    """
    exported_fields = (
        "server_address",
        "policy_type",
        "pretrained_name_or_path",
        "policy_device",
        "chunk_size_threshold",
        "fps",
        "actions_per_chunk",
        "task",
        "debug_visualize_queue_size",
        "aggregate_fn_name",
    )
    return {field_name: getattr(self, field_name) for field_name in exported_fields}
|
||||
29
src/lerobot/scripts/server/constants.py
Normal file
29
src/lerobot/scripts/server/constants.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
|
||||
|
||||
DEFAULT_FPS = 30
|
||||
|
||||
"""Server side: Running inference on (at most) 1/fps"""
|
||||
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
|
||||
"""Server side: Timeout for observation queue in seconds"""
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
|
||||
386
src/lerobot/scripts/server/helpers.py
Normal file
386
src/lerobot/scripts/server/helpers.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.transport import async_inference_pb2
|
||||
from lerobot.transport.utils import bytes_buffer_size
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
# A single action, and a chunk (sequence) of actions, are both plain tensors.
Action = torch.Tensor
ActionChunk = torch.Tensor

# The three aliases below share one runtime type; the names only document
# which stage of the pipeline a dictionary belongs to.

# Observation as received from the robot.
RawObservation = dict[str, torch.Tensor]

# Observation keyed like the ones recorded in a LeRobot dataset (keys differ from the raw ones).
LeRobotObservation = dict[str, torch.Tensor]

# Observation ready for policy inference (image entries resized).
Observation = dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
    """Plot the recorded action-queue size against the environment step index.

    Args:
        action_queue_size: One queue-size sample per environment step. May be empty,
            in which case an empty plot is shown.
    """
    # Imported lazily so matplotlib is only required when this debug helper is used.
    import matplotlib.pyplot as plt

    # `max()` raises ValueError on an empty sequence; fall back to 0 instead.
    peak = max(action_queue_size, default=0)
    # Keep 10% headroom above the peak; use 1 when the peak is 0 so the limits are never (0, 0).
    y_top = peak * 1.1 if peak > 0 else 1

    fig, ax = plt.subplots()
    ax.set_title("Action Queue Size Over Time")
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Action Queue Size")
    ax.set_ylim(0, y_top)
    ax.grid(True, alpha=0.3)
    ax.plot(range(len(action_queue_size)), action_queue_size)
    plt.show()
|
||||
|
||||
|
||||
def validate_robot_cameras_for_policy(
    lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
) -> None:
    """Assert that the robot's image keys are exactly the image keys the policy declares.

    Raises:
        AssertionError: if the two key sets differ.
    """
    image_keys = [key for key in lerobot_observation_features if is_image_key(key)]
    cameras_match = set(image_keys) == set(policy_image_features.keys())
    assert cameras_match, (
        f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
    )
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
    """Convert the robot's observation features to dataset features under the "observation" prefix (no video)."""
    hardware_features = robot.observation_features
    return hw_to_dataset_features(hardware_features, "observation", use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
    """Return True when `k` starts with the `OBS_IMAGES` prefix."""
    image_prefix = OBS_IMAGES
    return k.startswith(image_prefix)
|
||||
|
||||
|
||||
def resize_robot_observation_image(image: torch.Tensor, resize_dims: tuple[int, int, int]) -> torch.Tensor:
    """Resize a channel-last image to the policy resolution and return it channel-first.

    Args:
        image: 3-D image tensor laid out as (H, W, C).
        resize_dims: Target shape; only entries 1 and 2 (height, width) are used.

    Returns:
        The bilinearly resized image, laid out as (C, H', W').

    Raises:
        AssertionError: if `image` is not 3-dimensional.
    """
    assert image.ndim == 3, f"Image must be (H, W, C)! Received {image.shape}"
    # (H, W, C) -> (C, H, W) for resizing from robot observation resolution to policy image resolution
    image = image.permute(2, 0, 1)
    dims = (resize_dims[1], resize_dims[2])
    # Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
    image_batched = image.unsqueeze(0)
    # Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
    resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)

    return resized.squeeze(0)
|
||||
|
||||
|
||||
def raw_observation_to_observation(
    raw_observation: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
    device: str,
) -> Observation:
    """Turn a raw robot observation into a dictionary ready for policy inference.

    The raw observation is first re-keyed and resized by `prepare_raw_observation`. Then every
    tensor is moved to `device`; tensors whose key contains "image" additionally go through
    `prepare_image` and gain a leading batch dimension. Non-tensor entries are passed through.

    Args:
        raw_observation: Observation as received from the robot.
        lerobot_features: Dataset features used to re-key the raw observation.
        policy_image_features: Image features of the policy, used for resizing.
        device: Target device for all tensors.

    Returns:
        A new dictionary with the same keys, in the same order, as the prepared observation.
    """
    prepared = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)

    # Build a fresh dict instead of overwriting entries of the one being iterated.
    observation = {}
    for key, value in prepared.items():
        if not isinstance(value, torch.Tensor):
            # VLAs present natural-language instructions in observations
            observation[key] = value
        elif "image" in key:
            # Policy expects images in shape (B, C, H, W)
            observation[key] = prepare_image(value).unsqueeze(0).to(device)
        else:
            observation[key] = value.to(device)

    return observation
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
    """Cast an image to float32, scale it by 1/255 and return a memory-contiguous tensor.

    For 8-bit images with values in [0, 255] the result lies in [0, 1].
    """
    scaled = image.to(torch.float32).div(255)
    return scaled.contiguous()
|
||||
|
||||
|
||||
def extract_state_from_raw_observation(
    lerobot_obs: RawObservation,
) -> torch.Tensor:
    """Return the `OBS_STATE` entry as a tensor; a 1-D state gains a leading dimension of size 1."""
    state = torch.tensor(lerobot_obs[OBS_STATE])
    # Only a flat state vector is expanded; any other rank is returned unchanged.
    return state.unsqueeze(0) if state.ndim == 1 else state
|
||||
|
||||
|
||||
def extract_images_from_raw_observation(
    lerobot_obs: RawObservation,
    camera_key: str,
) -> torch.Tensor:
    """Return the image stored under `camera_key` as a tensor.

    Despite the plural name, a single tensor (one camera) is returned, not a dictionary.

    Args:
        lerobot_obs: Observation dictionary holding the image.
        camera_key: Key of the camera image to extract.
    """
    return torch.tensor(lerobot_obs[camera_key])
|
||||
|
||||
|
||||
def make_lerobot_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
) -> LeRobotObservation:
    """Build a dataset-style frame from a raw observation, using the "observation" prefix."""
    frame = build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
    return frame
|
||||
|
||||
|
||||
def prepare_raw_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
) -> Observation:
    """Match keys from the raw `robot_obs` dict to the keys expected by a given policy
    (passed as `policy_image_features`).

    Args:
        robot_obs: Observation as received from the robot; may carry a "task" entry.
        lerobot_features: Dataset features used to re-key the raw observation.
        policy_image_features: Image features of the policy; each `.shape` gives the resize target.

    Returns:
        A dictionary with the state, the optional "task" entry, and one resized image per image key.
    """
    # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
    # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
    lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)

    # 2. Greps all observation.images.<> keys
    image_keys = list(filter(is_image_key, lerobot_obs))
    # state's shape is expected as (B, state_dim)
    state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}

    # Turns the image features to (C, H, W) with H, W matching the policy image features.
    # This reduces the resolution of the images. Each image is converted to a tensor once.
    image_dict = {
        key: resize_robot_observation_image(
            extract_images_from_raw_observation(lerobot_obs, key), policy_image_features[key].shape
        )
        for key in image_keys
    }

    if "task" in robot_obs:
        state_dict["task"] = robot_obs["task"]

    return {**state_dict, **image_dict}
|
||||
|
||||
|
||||
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
    """
    Initialize logging through `init_logging` and return the logger called `name`.

    Args:
        name: Logger name (e.g., 'policy_server', 'robot_client')
        log_to_file: When True, a `logs` directory is created if needed and a file named
            `logs/<name>_<unix seconds>.log` is passed to `init_logging`.

    Returns:
        The `logging.Logger` registered under `name`.
    """
    log_file = None
    if log_to_file:
        os.makedirs("logs", exist_ok=True)
        # The creation time (whole seconds) is part of the file name.
        log_file = Path(f"logs/{name}_{int(time.time())}.log")

    init_logging(log_file=log_file, display_pid=False)

    return logging.getLogger(name)
|
||||
|
||||
|
||||
@dataclass
class TimedData:
    """Base record pairing a timestamp with a timestep; subclasses add the payload.

    Args:
        timestamp: Timestamp associated with the data, as a float.
        timestep: The timestep of the data.
    """

    timestamp: float
    timestep: int

    def get_timestamp(self) -> float:
        """Return the stored timestamp."""
        return self.timestamp

    def get_timestep(self) -> int:
        """Return the stored timestep."""
        return self.timestep
|
||||
|
||||
|
||||
@dataclass
class TimedAction(TimedData):
    """A `TimedData` record whose payload is an action."""

    action: Action

    def get_action(self) -> Action:
        """Return the wrapped action."""
        return self.action
|
||||
|
||||
|
||||
@dataclass
class TimedObservation(TimedData):
    """A `TimedData` record whose payload is a raw observation, plus a `must_go` flag."""

    observation: RawObservation
    # Defaults to False when not provided.
    must_go: bool = False

    def get_observation(self) -> RawObservation:
        """Return the wrapped observation."""
        return self.observation
|
||||
|
||||
|
||||
@dataclass
class FPSTracker:
    """Tracks the average rate at which observations arrive, next to a target rate."""

    target_fps: float
    # None until the first call to `calculate_fps_metrics` (or after `reset`).
    first_timestamp: float = None
    total_obs_count: int = 0

    def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
        """Count one observation and return the average FPS since the first one.

        Returns:
            A dict with "avg_fps" and "target_fps". "avg_fps" is 0.0 while the elapsed time
            since the first observation is not larger than 1e-6.
        """
        self.total_obs_count += 1

        # The first observation seen fixes the time origin.
        if self.first_timestamp is None:
            self.first_timestamp = current_timestamp

        elapsed = current_timestamp - self.first_timestamp
        if elapsed > 1e-6:
            # N observations span N - 1 intervals.
            avg_fps = (self.total_obs_count - 1) / elapsed
        else:
            avg_fps = 0.0

        return {"avg_fps": avg_fps, "target_fps": self.target_fps}

    def reset(self):
        """Forget the time origin and the observation count."""
        self.total_obs_count = 0
        self.first_timestamp = None
|
||||
|
||||
|
||||
@dataclass
class RemotePolicyConfig:
    """Policy specification describing which policy to load and how to run it.

    Args:
        policy_type: Policy type identifier.
        pretrained_name_or_path: Pretrained model name or path.
        lerobot_features: Dataset-style observation features; annotated `dict[str, dict]`
            to agree with the functions in this module that produce and consume it.
        actions_per_chunk: Number of actions per chunk.
        device: Device for the policy; defaults to "cpu".
    """

    policy_type: str
    pretrained_name_or_path: str
    lerobot_features: dict[str, dict]
    actions_per_chunk: int
    device: str = "cpu"
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
|
||||
|
||||
|
||||
def observations_similar(
    obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
) -> bool:
    """Tell whether two observations are similar under a tolerance threshold.

    Similarity is measured in joint space only: the state of each observation is extracted
    and the two states are compared with `_compare_observation_states`.

    NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
    An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
    to surpass this joint-space similarity check.
    """
    first_state, second_state = (
        extract_state_from_raw_observation(make_lerobot_observation(timed_obs.get_observation(), lerobot_features))
        for timed_obs in (obs1, obs2)
    )

    return _compare_observation_states(first_state, second_state, atol=atol)
|
||||
|
||||
|
||||
def send_bytes_in_chunks(
    buffer: bytes,
    message_class: Any,
    log_prefix: str = "",
    silent: bool = True,
    chunk_size: int = 3 * 1024 * 1024,
):
    """Yield `buffer` as a sequence of `message_class` messages of at most `chunk_size` bytes.

    Each message carries a transfer state: the message that completes the buffer is
    TRANSFER_END (this includes a buffer that fits in a single chunk), otherwise the
    first message is TRANSFER_BEGIN and the remaining ones are TRANSFER_MIDDLE.
    An empty buffer yields no message at all.

    Args:
        buffer: Bytes to send.
        message_class: Callable building a message from `transfer_state` and `data`.
        log_prefix: Prefix prepended to every log line.
        silent: Log at debug level when True, at info level otherwise.
        chunk_size: Maximum payload size of a message, in bytes.

    Raises:
        ValueError: if `chunk_size` is not strictly positive. As this is a generator,
            the error is raised when iteration starts.
    """
    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
    # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
    # chunk size as I am using it to send image observations.
    if chunk_size <= 0:
        # A non-positive chunk size would never make progress in the loop below.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    buffer = io.BytesIO(buffer)
    size_in_bytes = bytes_buffer_size(buffer)

    sent_bytes = 0

    logging_method = logging.info if not silent else logging.debug

    logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB")

    while sent_bytes < size_in_bytes:
        transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE

        if sent_bytes + chunk_size >= size_in_bytes:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_END
        elif sent_bytes == 0:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN

        size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
        chunk = buffer.read(size_to_read)

        yield message_class(transfer_state=transfer_state, data=chunk)
        sent_bytes += size_to_read
        logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")

    logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(
    iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
):  # type: ignore
    """Reassemble one chunked payload from a stream of transfer messages.

    Args:
        iterator: Iterable of messages exposing `transfer_state` and `data`.
        continue_receiving: Receiving goes on while this event is set; once it is found
            cleared, the function returns without a payload.
        logger: Logger used for progress messages.
        log_prefix: Prefix prepended to every log line.

    Returns:
        The complete payload as `bytes` once a TRANSFER_END message is received, or None
        when the event is cleared or the stream ends before TRANSFER_END.

    Raises:
        ValueError: if a message carries an unknown transfer state.
    """
    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
    # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
    # is the opposite of the HIL-SERL design (my event showcases keeping on running instead of shutdown)
    bytes_buffer = io.BytesIO()
    step = 0

    logger.info(f"{log_prefix} Starting receiver")
    for item in iterator:
        logger.debug(f"{log_prefix} Received item")
        if not continue_receiving.is_set():
            logger.info(f"{log_prefix} Shutting down receiver")
            return None

        if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
            # A new transfer starts: drop whatever was buffered and restart the step counter.
            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)
            bytes_buffer.write(item.data)
            step = 0
            logger.debug(f"{log_prefix} Received data at step 0")

        elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
            bytes_buffer.write(item.data)
            step += 1
            logger.debug(f"{log_prefix} Received data at step {step}")

        elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
            bytes_buffer.write(item.data)
            logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")

            complete_bytes = bytes_buffer.getvalue()

            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)

            logger.debug(f"{log_prefix} Queue updated")
            return complete_bytes

        else:
            logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
            raise ValueError(f"Received unknown transfer state {item.transfer_state}")

    # The stream was exhausted without a TRANSFER_END message: there is no complete payload.
    logger.warning(f"{log_prefix} Stream ended before TRANSFER_END was received")
    return None
|
||||
403
src/lerobot/scripts/server/policy_server.py
Normal file
403
src/lerobot/scripts/server/policy_server.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python src/lerobot/scripts/server/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
--inference_latency=0.033 \
|
||||
--obs_queue_timeout=1
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
receive_bytes_in_chunks,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
async_inference_pb2, # type: ignore
|
||||
async_inference_pb2_grpc, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: PolicyServerConfig):
|
||||
self.config = config
|
||||
self._running_event = threading.Event()
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
||||
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
self._predicted_timesteps_lock = threading.Lock()
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
self.last_processed_obs = None
|
||||
|
||||
# Attributes will be set by SendPolicyInstructions
|
||||
self.device = None
|
||||
self.policy_type = None
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return self._running_event.is_set()
|
||||
|
||||
@property
|
||||
def policy_image_features(self):
|
||||
return self.policy.config.image_features
|
||||
|
||||
def _reset_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self._running_event.clear()
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._reset_server()
|
||||
self._running_event.set()
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
|
||||
if not self.running:
|
||||
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
client_id = context.peer()
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
|
||||
if not isinstance(policy_specs, RemotePolicyConfig):
|
||||
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
|
||||
|
||||
if policy_specs.policy_type not in SUPPORTED_POLICIES:
|
||||
raise ValueError(
|
||||
f"Policy type {policy_specs.policy_type} not supported. "
|
||||
f"Supported policies: {SUPPORTED_POLICIES}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Receiving policy instructions from {client_id} | "
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # act, pi0, etc.
|
||||
self.lerobot_features = policy_specs.lerobot_features
|
||||
self.actions_per_chunk = policy_specs.actions_per_chunk
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
receive_time = time.time() # comparing timestamps so need time.time()
|
||||
start_deserialize = time.perf_counter()
|
||||
received_bytes = receive_bytes_in_chunks(
|
||||
request_iterator, self._running_event, self.logger
|
||||
) # blocking call while looping over request_iterator
|
||||
timed_observation = pickle.loads(received_bytes) # nosec
|
||||
deserialize_time = time.perf_counter() - start_deserialize
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.info(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Deserialization time: {deserialize_time:.6f}s"
|
||||
)
|
||||
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def GetActions(self, request, context): # noqa: N802
|
||||
"""Returns actions to the robot client. Actions are sent as a single
|
||||
chunk, containing multiple actions."""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
getactions_starts = time.perf_counter()
|
||||
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
|
||||
start_time = time.perf_counter()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
actions_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
|
||||
# Create and return the action chunk
|
||||
actions = async_inference_pb2.Actions(data=actions_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.2f}s |"
|
||||
f"Serialize time: {serialize_time:.2f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.2f}s"
|
||||
)
|
||||
|
||||
time.sleep(
|
||||
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
|
||||
) # sleep controls inference latency
|
||||
|
||||
return actions
|
||||
|
||||
except Empty: # no observation added to queue in obs_queue_timeout
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
"""Check if the observation is valid to be processed by the policy"""
|
||||
with self._predicted_timesteps_lock:
|
||||
predicted_timesteps = self._predicted_timesteps
|
||||
|
||||
if obs.get_timestep() in predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
|
||||
|
||||
def _enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if (
|
||||
obs.must_go
|
||||
or self.last_processed_obs is None
|
||||
or self._obs_sanity_checks(obs, self.last_processed_obs)
|
||||
):
|
||||
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
|
||||
self.logger.debug(
|
||||
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
|
||||
)
|
||||
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
|
||||
"""
|
||||
Prepare observation, ready for policy inference.
|
||||
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
|
||||
client and then convert them to float32 [0,1] images here, before running inference.
|
||||
"""
|
||||
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
self.device,
|
||||
)
|
||||
# processed Observation - right keys, right dtype, right image shape
|
||||
|
||||
return observation
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
if chunk.ndim != 3:
|
||||
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
|
||||
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
    """Predict an action chunk based on an observation.

    Runs three timed stages: (1) observation preparation, (2) policy
    inference, (3) post-processing of the predicted tensor into TimedAction
    instances anchored at the observation's timestamp and timestep.

    Args:
        observation_t: Timed observation to run inference on. It is also
            stored as `self.last_processed_obs`.

    Returns:
        The predicted actions wrapped as TimedAction instances.
    """
    inference_starts = time.perf_counter()

    # 1. Prepare observation
    start_time = time.perf_counter()
    observation = self._prepare_observation(observation_t)
    preprocessing_time = time.perf_counter() - start_time

    self.last_processed_obs: TimedObservation = observation_t

    # 2. Get action chunk
    start_time = time.perf_counter()
    action_tensor = self._get_action_chunk(observation)
    inference_time = time.perf_counter() - start_time

    # 3. Post-inference processing
    start_time = time.perf_counter()
    # Move to CPU before serializing
    action_tensor = action_tensor.cpu().squeeze(0)

    action_chunk = self._time_action_chunk(
        observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
    )
    postprocessing_time = time.perf_counter() - start_time
    inference_stops = time.perf_counter()

    total_time = inference_stops - inference_starts

    self.logger.info(
        f"Observation {observation_t.get_timestep()} | Inference time: {1000 * total_time:.2f}ms"
    )

    # Full-process latency breakdown for debugging purposes. Each *_time value
    # is already a duration, so it is reported as-is rather than subtracted
    # from another stage or from an absolute perf_counter() reading.
    self.logger.debug(
        f"Observation {observation_t.get_timestep()} | "
        f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
        f"Inference time: {1000 * inference_time:.2f}ms | "
        f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
        f"Total time: {1000 * total_time:.2f}ms"
    )

    return action_chunk
|
||||
|
||||
def stop(self):
    """Reset the server state via `_reset_server()` and log that the server is stopping."""
    self._reset_server()
    self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
@draccus.wrap()
def serve(cfg: PolicyServerConfig):
    """Start a PolicyServer behind a gRPC endpoint and block until it terminates.

    Args:
        cfg: PolicyServerConfig instance; its `host` and `port` fields define
            the address the gRPC server listens on.
    """
    logging.info(pformat(asdict(cfg)))

    # Create the servicer instance first
    policy_server = PolicyServer(cfg)

    # Setup the gRPC server and register the servicer
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"{cfg.host}:{cfg.port}")

    server.start()
    # Announce the server only once it has actually been started, so the log
    # is not misleading if start() fails.
    policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")

    server.wait_for_termination()

    policy_server.logger.info("Server terminated")


if __name__ == "__main__":
    serve()
|
||||
509
src/lerobot/scripts/server/robot_client.py
Normal file
509
src/lerobot/scripts/server/robot_client.py
Normal file
@@ -0,0 +1,509 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--task="dummy" \
|
||||
--server_address=127.0.0.1:8080 \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
--debug_visualize_queue_size=True
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
|
||||
from lerobot.scripts.server.helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RawObservation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
send_bytes_in_chunks,
|
||||
validate_robot_cameras_for_policy,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
async_inference_pb2, # type: ignore
|
||||
async_inference_pb2_grpc, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: RobotClientConfig):
    """Initialize RobotClient with unified configuration.

    Builds and connects the robot, optionally validates its cameras against
    the pretrained policy config, opens an insecure gRPC channel to the
    policy server and sets up the state shared between the control loop and
    the action receiver thread.

    Args:
        config: RobotClientConfig containing all configuration parameters
    """
    self.config = config

    # The robot is connected first: its features feed the policy setup below.
    self.robot = make_robot_from_config(config.robot)
    self.robot.connect()

    robot_features = map_robot_keys_to_lerobot_features(self.robot)

    if config.verify_robot_cameras:
        # The cameras specified for inference must match the ones supported by the chosen policy
        pretrained_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
        validate_robot_cameras_for_policy(robot_features, pretrained_config.image_features)

    self.server_address = config.server_address

    # Instructions sent to the server in start()
    self.policy_config = RemotePolicyConfig(
        config.policy_type,
        config.pretrained_name_or_path,
        robot_features,
        config.actions_per_chunk,
        config.policy_device,
    )

    # gRPC plumbing
    self.channel = grpc.insecure_channel(self.server_address)
    self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
    self.logger.info(f"Initializing client to connect to server at {self.server_address}")

    # Set by start(), cleared by stop(); exposed through the `running` property
    self._running_event = threading.Event()

    # Timestep of the last executed action (-1: none yet) and largest chunk seen so far
    self.latest_action_lock = threading.Lock()
    self.latest_action = -1
    self.action_chunk_size = -1

    self._chunk_size_threshold = config.chunk_size_threshold

    # Local action queue and its bookkeeping
    self.action_queue = Queue()
    self.action_queue_lock = threading.Lock()  # Protect queue operations
    self.action_queue_size = []
    self.start_barrier = threading.Barrier(2)  # 2 threads: action receiver, control loop

    # FPS measurement
    self.fps_tracker = FPSTracker(target_fps=self.config.fps)

    self.logger.info("Robot connected and ready")

    # Event used for thread-safe coordination of "must-go" observations
    self.must_go = threading.Event()
    self.must_go.set()  # Initially set - observations qualify for direct processing
|
||||
|
||||
@property
def running(self):
    """Whether the client is running, i.e. whether the internal running event is set."""
    is_running = self._running_event.is_set()
    return is_running
|
||||
|
||||
def start(self):
    """Start the robot client and connect to the policy server.

    Performs the `Ready` handshake, sends the pickled policy configuration
    through `SendPolicyInstructions` and then marks the client as running.

    Returns:
        True when both calls succeed, False when a `grpc.RpcError` is raised
        (the error is logged and the client is not marked as running).
    """
    try:
        # Client-server handshake
        handshake_began = time.perf_counter()
        self.stub.Ready(async_inference_pb2.Empty())
        handshake_ended = time.perf_counter()
        self.logger.debug(f"Connected to policy server in {handshake_ended - handshake_began:.4f}s")

        # Send policy instructions
        policy_setup = async_inference_pb2.PolicySetup(data=pickle.dumps(self.policy_config))

        self.logger.info("Sending policy instructions to policy server")
        self.logger.debug(
            f"Policy type: {self.policy_config.policy_type} | "
            f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
            f"Device: {self.policy_config.device}"
        )

        self.stub.SendPolicyInstructions(policy_setup)

        self._running_event.set()

    except grpc.RpcError as e:
        self.logger.error(f"Failed to connect to policy server: {e}")
        return False

    return True
|
||||
|
||||
def stop(self):
    """Stop the robot client.

    Clears the running flag, then disconnects the robot and closes the gRPC
    channel, in that order.
    """
    self._running_event.clear()

    # Release the hardware first...
    self.robot.disconnect()
    self.logger.debug("Robot disconnected")

    # ...then the connection to the policy server.
    self.channel.close()
    self.logger.debug("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
    self,
    obs: TimedObservation,
) -> bool:
    """Pickle an observation and stream it to the policy server in chunks.

    Args:
        obs: Observation to send.

    Returns:
        True if the observation was sent successfully, False when a
        `grpc.RpcError` is raised while sending (the error is logged).

    Raises:
        RuntimeError: The client has not been started.
        ValueError: `obs` is not a TimedObservation.
    """
    if not self.running:
        raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")

    if not isinstance(obs, TimedObservation):
        raise ValueError("Input observation needs to be a TimedObservation!")

    serialize_began = time.perf_counter()
    observation_bytes = pickle.dumps(obs)
    serialize_time = time.perf_counter() - serialize_began
    self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")

    try:
        chunk_iterator = send_bytes_in_chunks(
            observation_bytes,
            async_inference_pb2.Observation,
            log_prefix="[CLIENT] Observation",
            silent=True,
        )
        self.stub.SendObservations(chunk_iterator)
        sent_timestep = obs.get_timestep()
        self.logger.info(f"Sent observation #{sent_timestep} | ")

    except grpc.RpcError as e:
        self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
        return False

    return True
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
with self.action_queue_lock:
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _aggregate_action_queues(
    self,
    incoming_actions: list[TimedAction],
    aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
    """Rebuild the action queue from an incoming chunk, aggregating overlapping timesteps.

    For every incoming action:
      * it is skipped when its timestep is not newer than the last executed action;
      * it is enqueued as-is when its timestep is not currently queued;
      * otherwise the queued and the incoming action tensors are combined with
        `aggregate_fn` and the result is enqueued.

    The rebuilt queue then replaces `self.action_queue`; it only contains
    entries derived from `incoming_actions`.

    Args:
        incoming_actions: Newly received actions.
        aggregate_fn: Called as `aggregate_fn(queued_action, incoming_action)`
            for timesteps present in both. Defaults to keeping the incoming action.
    """
    if aggregate_fn is None:
        # default aggregate function: take the latest action
        def aggregate_fn(x1, x2):
            return x2

    # Snapshot the queued actions while holding the lock. The control loop pops
    # from the underlying deque under this same lock; iterating the deque after
    # releasing it could fail with "deque mutated during iteration".
    with self.action_queue_lock:
        current_action_queue = {
            action.get_timestep(): action.get_action() for action in self.action_queue.queue
        }

    future_action_queue = Queue()
    for new_action in incoming_actions:
        with self.latest_action_lock:
            latest_action = self.latest_action

        timestep = new_action.get_timestep()

        # New action is not newer than the last executed action, skip it
        if timestep <= latest_action:
            continue

        # Timestep not queued yet: add the new action directly
        if timestep not in current_action_queue:
            future_action_queue.put(new_action)
            continue

        # Timestep already queued: aggregate the queued and incoming actions
        # TODO: There is probably a way to do this with broadcasting of the two action tensors
        future_action_queue.put(
            TimedAction(
                timestamp=new_action.get_timestamp(),
                timestep=timestep,
                action=aggregate_fn(current_action_queue[timestep], new_action.get_action()),
            )
        )

    with self.action_queue_lock:
        self.action_queue = future_action_queue
|
||||
|
||||
def receive_actions(self, verbose: bool = False):
    """Receive action chunks from the policy server and merge them into the local queue.

    Runs until the client stops. Each iteration requests actions with
    `GetActions`, deserializes them, updates the largest chunk size seen so
    far, aggregates them into the action queue and re-arms the must-go event.
    gRPC errors are logged and the loop continues.

    Args:
        verbose: When True, log queue contents before and after each update
            (only for non-empty chunks).
    """
    # Wait at barrier for synchronized start
    self.start_barrier.wait()
    self.logger.info("Action receiving thread starting")

    while self.running:
        try:
            # Ask the server for the next chunk of actions
            actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
            if len(actions_chunk.data) == 0:
                continue  # received `Empty` from server, wait for next call

            receive_time = time.time()

            # Deserialize bytes back into list[TimedAction]
            # NOTE(review): pickle.loads on bytes received over the network can execute
            # arbitrary code; only connect to a trusted policy server.
            deserialize_start = time.perf_counter()
            timed_actions = pickle.loads(actions_chunk.data)  # nosec
            deserialize_time = time.perf_counter() - deserialize_start

            self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))

            # Detailed logging indexes the first/last action, so it needs a non-empty chunk.
            # The same flag guards the post-update logging, which reuses the values computed here.
            log_details = verbose and len(timed_actions) > 0

            if log_details:
                with self.latest_action_lock:
                    latest_action = self.latest_action

                self.logger.debug(f"Current latest action: {latest_action}")

                # Get queue state before changes
                old_size, old_timesteps = self._inspect_action_queue()
                if not old_timesteps:
                    old_timesteps = [latest_action]  # queue was empty

                # Log incoming actions
                incoming_timesteps = [a.get_timestep() for a in timed_actions]

                first_action_timestep = timed_actions[0].get_timestep()
                server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000

                self.logger.info(
                    f"Received action chunk for step #{first_action_timestep} | "
                    f"Latest action: #{latest_action} | "
                    f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                    f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
                    f"Deserialization time: {deserialize_time * 1000:.2f}ms"
                )

            # Update action queue
            start_time = time.perf_counter()
            self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
            queue_update_time = time.perf_counter() - start_time

            self.must_go.set()  # after receiving actions, next empty queue triggers must-go processing!

            if log_details:
                # Get queue state after changes
                new_size, new_timesteps = self._inspect_action_queue()

                with self.latest_action_lock:
                    latest_action = self.latest_action

                if not new_timesteps:
                    # every incoming action was stale: the updated queue is empty
                    new_timesteps = [latest_action]

                self.logger.info(
                    f"Latest action: {latest_action} | "
                    f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
                    f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                    f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
                )
                self.logger.debug(
                    f"Queue update complete ({queue_update_time:.6f}s) | "
                    f"Before: {old_size} items | "
                    f"After: {new_size} items | "
                )

        except grpc.RpcError as e:
            self.logger.error(f"Error receiving actions: {e}")
|
||||
|
||||
def actions_available(self):
    """Return True when the local action queue holds at least one action."""
    with self.action_queue_lock:
        queue_is_empty = self.action_queue.empty()
    return not queue_is_empty
|
||||
|
||||
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
    """Pop the next action from the local queue and send it to the robot.

    The queue size observed before popping is appended to
    `self.action_queue_size`, and `self.latest_action` is updated with the
    timestep of the executed action.

    Args:
        verbose: When True, log the executed action and queue size.

    Returns:
        Whatever `self.robot.send_action` returns for the executed action.
    """
    # Hold the lock only for the queue operations
    pop_began = time.perf_counter()
    with self.action_queue_lock:
        self.action_queue_size.append(self.action_queue.qsize())
        next_action = self.action_queue.get_nowait()
    pop_duration = time.perf_counter() - pop_began

    action_dict = self._action_tensor_to_action_dict(next_action.get_action())
    performed = self.robot.send_action(action_dict)

    with self.latest_action_lock:
        self.latest_action = next_action.get_timestep()

    if verbose:
        with self.action_queue_lock:
            remaining = self.action_queue.qsize()

        self.logger.debug(
            f"Ts={next_action.get_timestamp()} | "
            f"Action #{next_action.get_timestep()} performed | "
            f"Queue size: {remaining}"
        )

        self.logger.debug(
            f"Popping action from queue to perform took {pop_duration:.6f}s | Queue size: {remaining}"
        )

    return performed
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
with self.action_queue_lock:
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
|
||||
|
||||
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
    """Capture an observation from the robot and send it to the policy server.

    The observation is tagged with `task`, timestamped with `time.time()` and
    given the timestep of the last executed action (at least 0). It is marked
    as must-go when the must-go event is set and the action queue is empty;
    in that case the event is cleared afterwards.

    Args:
        task: Task description stored under the "task" key of the observation.
        verbose: When True, log FPS metrics and the capture duration.

    Returns:
        The raw observation that was captured. Any exception is logged and
        swallowed, in which case nothing is returned (None).
    """
    try:
        capture_began = time.perf_counter()

        raw_observation: RawObservation = self.robot.get_observation()
        raw_observation["task"] = task

        with self.latest_action_lock:
            last_timestep = self.latest_action

        observation = TimedObservation(
            timestamp=time.time(),  # need time.time() to compare timestamps across client and server
            observation=raw_observation,
            timestep=max(last_timestep, 0),
        )

        obs_capture_time = time.perf_counter() - capture_began

        # If there are no actions left in the queue, the observation must go through processing!
        with self.action_queue_lock:
            observation.must_go = self.must_go.is_set() and self.action_queue.empty()
            queue_size = self.action_queue.qsize()

        self.send_observation(observation)

        self.logger.debug(f"QUEUE SIZE: {queue_size} (Must go: {observation.must_go})")
        if observation.must_go:
            # must-go event will be set again after receiving actions
            self.must_go.clear()

        if verbose:
            # Calculate comprehensive FPS metrics
            fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())

            self.logger.info(
                f"Obs #{observation.get_timestep()} | "
                f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
                f"Target: {fps_metrics['target_fps']:.2f}"
            )

            self.logger.debug(
                f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
            )

        return raw_observation

    except Exception as e:
        self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
    """Execute queued actions and stream observations until the client stops.

    Each iteration (1) performs one action when the queue is not empty,
    (2) sends an observation when the client is ready to, then sleeps for the
    remainder of `config.environment_dt` to hold the control frequency.

    Args:
        task: Task description attached to every observation.
        verbose: Forwarded to the action and observation steps.

    Returns:
        The last captured observation and the last performed action
        (each None if that step never ran).
    """
    # Wait at barrier for synchronized start
    self.start_barrier.wait()
    self.logger.info("Control loop thread starting")

    last_action = None
    last_observation = None

    while self.running:
        iteration_began = time.perf_counter()

        # (1) Perform an action, when available
        if self.actions_available():
            last_action = self.control_loop_action(verbose)

        # (2) Stream an observation to the remote policy server
        if self._ready_to_send_observation():
            last_observation = self.control_loop_observation(task, verbose)

        self.logger.info(f"Control loop (ms): {(time.perf_counter() - iteration_began) * 1000:.2f}")

        # Dynamically adjust sleep time to maintain the desired control frequency
        remaining = self.config.environment_dt - (time.perf_counter() - iteration_began)
        time.sleep(max(0, remaining))

    return last_observation, last_action
|
||||
|
||||
|
||||
@draccus.wrap()
def async_client(cfg: RobotClientConfig):
    """Run a RobotClient: receive actions in a background thread, control the robot in the main one.

    Args:
        cfg: RobotClientConfig describing the robot, the task and the policy server.

    Raises:
        ValueError: `cfg.robot.type` is not in SUPPORTED_ROBOTS.
    """
    logging.info(pformat(asdict(cfg)))

    if cfg.robot.type not in SUPPORTED_ROBOTS:
        raise ValueError(f"Robot {cfg.robot.type} not yet supported!")

    client = RobotClient(cfg)

    if not client.start():
        # RobotClient.__init__ already connected the robot and opened the gRPC
        # channel; release both when the handshake with the server fails.
        client.stop()
        return

    client.logger.info("Starting action receiver thread...")

    # Create and start action receiver thread
    action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
    action_receiver_thread.start()

    try:
        # The main thread runs the control loop
        client.control_loop(task=cfg.task)

    finally:
        client.stop()
        action_receiver_thread.join()
        if cfg.debug_visualize_queue_size:
            visualize_action_queue_size(client.action_queue_size)
        client.logger.info("Client stopped")


if __name__ == "__main__":
    async_client()  # run the client
|
||||
@@ -16,7 +16,6 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
@@ -30,7 +29,6 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.datasets.utils_must import multidataset_collate_fn
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
@@ -175,23 +173,14 @@ def train(cfg: TrainPipelineConfig):
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
|
||||
keys_to_max_dim = {
|
||||
"action": (32,),
|
||||
"observation.state": (32,),
|
||||
"observation.image": (3, 1080, 1920),
|
||||
"observation.image2": (3, 1080, 1920),
|
||||
}
|
||||
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
@@ -218,7 +207,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
|
||||
@@ -35,6 +35,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
so101_follower_torque,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
@@ -52,6 +53,7 @@ COMPATIBLE_DEVICES = [
|
||||
"so101_follower",
|
||||
"so101_leader",
|
||||
"lekiwi",
|
||||
"so101_follower_t",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
@@ -52,6 +53,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
gamepad,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
|
||||
4
src/lerobot/teleoperators/homunculus/__init__.py
Normal file
4
src/lerobot/teleoperators/homunculus/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig
|
||||
from .homunculus_arm import HomunculusArm
|
||||
from .homunculus_glove import HomunculusGlove
|
||||
from .joints_translation import homunculus_glove_to_hope_jr_hand
|
||||
38
src/lerobot/teleoperators/homunculus/config_homunculus.py
Normal file
38
src/lerobot/teleoperators/homunculus/config_homunculus.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("homunculus_glove")
@dataclass
class HomunculusGloveConfig(TeleoperatorConfig):
    """Configuration of a Homunculus glove teleoperator, registered as "homunculus_glove"."""

    port: str  # Port to connect to the glove
    side: str  # "left" / "right"
    baud_rate: int = 115_200

    def __post_init__(self):
        """Validate `side`.

        Raises:
            ValueError: `side` is neither "left" nor "right".
        """
        if self.side not in ["right", "left"]:
            # Name the offending field and the accepted values instead of
            # raising with the bare value only.
            raise ValueError(f"`side` must be 'left' or 'right', got {self.side!r}.")
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("homunculus_arm")
@dataclass
class HomunculusArmConfig(TeleoperatorConfig):
    """Configuration of a Homunculus arm teleoperator, registered as "homunculus_arm"."""

    # Port to connect to the arm
    port: str
    # Baud rate used when communicating over `port`
    baud_rate: int = 115_200
|
||||
310
src/lerobot/teleoperators/homunculus/homunculus_arm.py
Normal file
310
src/lerobot/teleoperators/homunculus/homunculus_arm.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from pprint import pformat
|
||||
from typing import Deque, Dict, Optional
|
||||
|
||||
import serial
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_homunculus import HomunculusArmConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HomunculusArm(Teleoperator):
|
||||
"""
|
||||
Homunculus Arm designed by Hugging Face.
|
||||
"""
|
||||
|
||||
config_class = HomunculusArmConfig
|
||||
name = "homunculus_arm"
|
||||
|
||||
def __init__(self, config: HomunculusArmConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
|
||||
self.serial_lock = threading.Lock()
|
||||
|
||||
self.joints = {
|
||||
"shoulder_pitch": MotorNormMode.RANGE_M100_100,
|
||||
"shoulder_yaw": MotorNormMode.RANGE_M100_100,
|
||||
"shoulder_roll": MotorNormMode.RANGE_M100_100,
|
||||
"elbow_flex": MotorNormMode.RANGE_M100_100,
|
||||
"wrist_roll": MotorNormMode.RANGE_M100_100,
|
||||
"wrist_yaw": MotorNormMode.RANGE_M100_100,
|
||||
"wrist_pitch": MotorNormMode.RANGE_M100_100,
|
||||
}
|
||||
n = 50
|
||||
# EMA parameters ---------------------------------------------------
|
||||
self.n: int = n
|
||||
self.alpha: float = 2 / (n + 1)
|
||||
# one deque *per joint* so we can inspect raw history if needed
|
||||
self._buffers: Dict[str, Deque[int]] = {
|
||||
joint: deque(maxlen=n)
|
||||
for joint in (
|
||||
"shoulder_pitch",
|
||||
"shoulder_yaw",
|
||||
"shoulder_roll",
|
||||
"elbow_flex",
|
||||
"wrist_roll",
|
||||
"wrist_yaw",
|
||||
"wrist_pitch",
|
||||
)
|
||||
}
|
||||
# running EMA value per joint – lazily initialised on first read
|
||||
self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers)
|
||||
|
||||
self._state: dict[str, float] | None = None
|
||||
self.new_state_event = threading.Event()
|
||||
self.stop_event = threading.Event()
|
||||
self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
|
||||
self.state_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return {f"{joint}.pos": float for joint in self.joints}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
with self.serial_lock:
|
||||
return self.serial.is_open and self.thread.is_alive()
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if not self.serial.is_open:
|
||||
self.serial.open()
|
||||
self.thread.start()
|
||||
|
||||
# wait for the thread to ramp up & 1st state to be ready
|
||||
if not self.new_state_event.wait(timeout=2):
|
||||
raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
|
||||
|
||||
if not self.is_calibrated and calibrate:
|
||||
self.calibrate()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.calibration_fpath.is_file()
|
||||
|
||||
def calibrate(self) -> None:
    """Record every joint's range of motion interactively and save it as the calibration.

    All joints are recorded in a single pass. Each joint gets a `MotorCalibration` whose id is its
    position in `self.joints`, with `drive_mode=0` and `homing_offset=0`.
    """
    print(
        "\nMove all joints through their entire range of motion."
        "\nRecording positions. Press ENTER to stop..."
    )
    lows, highs = self._record_ranges_of_motion()

    self.calibration = {}
    for motor_id, joint_name in enumerate(self.joints):
        entry = MotorCalibration(
            id=motor_id,
            drive_mode=0,
            homing_offset=0,
            range_min=lows[joint_name],
            range_max=highs[joint_name],
        )
        self.calibration[joint_name] = entry

    self._save_calibration()
    print("Calibration saved to", self.calibration_fpath)
|
||||
|
||||
# TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to a utility to reduce duplicated code.
|
||||
def _record_ranges_of_motion(
|
||||
self, joints: list[str] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[str, int], dict[str, int]]:
|
||||
"""Interactively record the min/max encoder values of each joint.
|
||||
|
||||
Move the joints while the method streams live positions. Press :kbd:`Enter` to finish.
|
||||
|
||||
Args:
|
||||
joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`).
|
||||
display_values (bool, optional): When `True` (default) a live table is printed to the console.
|
||||
|
||||
Raises:
|
||||
TypeError: `joints` is not `None` or a list.
|
||||
ValueError: any joint's recorded min and max are the same.
|
||||
|
||||
Returns:
|
||||
tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values
|
||||
observed for each joint.
|
||||
"""
|
||||
if joints is None:
|
||||
joints = list(self.joints)
|
||||
elif not isinstance(joints, list):
|
||||
raise TypeError(joints)
|
||||
|
||||
display_len = max(len(key) for key in joints)
|
||||
|
||||
start_positions = self._read(joints, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
user_pressed_enter = False
|
||||
while not user_pressed_enter:
|
||||
positions = self._read(joints, normalize=False)
|
||||
mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()}
|
||||
maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()}
|
||||
|
||||
if display_values:
|
||||
print("\n-------------------------------------------")
|
||||
print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
|
||||
for joint in joints:
|
||||
print(
|
||||
f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}"
|
||||
)
|
||||
|
||||
if enter_pressed():
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(joints) + 3)
|
||||
|
||||
same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]]
|
||||
if same_min_max:
|
||||
raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def configure(self) -> None:
    """Do nothing: no configuration step is implemented for this device."""
    return None
|
||||
|
||||
# TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to a utility to reduce duplicated code.
|
||||
def _normalize(self, values: dict[str, int]) -> dict[str, float]:
|
||||
if not self.calibration:
|
||||
raise RuntimeError(f"{self} has no calibration registered.")
|
||||
|
||||
normalized_values = {}
|
||||
for joint, val in values.items():
|
||||
min_ = self.calibration[joint].range_min
|
||||
max_ = self.calibration[joint].range_max
|
||||
drive_mode = self.calibration[joint].drive_mode
|
||||
bounded_val = min(max_, max(min_, val))
|
||||
|
||||
if self.joints[joint] is MotorNormMode.RANGE_M100_100:
|
||||
norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
|
||||
normalized_values[joint] = -norm if drive_mode else norm
|
||||
elif self.joints[joint] is MotorNormMode.RANGE_0_100:
|
||||
norm = ((bounded_val - min_) / (max_ - min_)) * 100
|
||||
normalized_values[joint] = 100 - norm if drive_mode else norm
|
||||
|
||||
return normalized_values
|
||||
|
||||
def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, float]:
|
||||
"""Update buffers & running EMA values; return smoothed dict."""
|
||||
smoothed: Dict[str, float] = {}
|
||||
for joint, value in raw.items():
|
||||
# maintain raw history
|
||||
self._buffers[joint].append(value)
|
||||
|
||||
# initialise on first run
|
||||
if self._ema[joint] is None:
|
||||
self._ema[joint] = float(value)
|
||||
else:
|
||||
self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint]
|
||||
|
||||
smoothed[joint] = self._ema[joint]
|
||||
return smoothed
|
||||
|
||||
def _read(
|
||||
self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
|
||||
) -> dict[str, int | float]:
|
||||
"""
|
||||
Return the most recent (single) values from self.last_d,
|
||||
optionally applying calibration.
|
||||
"""
|
||||
if not self.new_state_event.wait(timeout=timeout):
|
||||
raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
|
||||
|
||||
with self.state_lock:
|
||||
state = self._state
|
||||
|
||||
self.new_state_event.clear()
|
||||
|
||||
if state is None:
|
||||
raise RuntimeError(f"{self} Internal error: Event set but no state available.")
|
||||
|
||||
if joints is not None:
|
||||
state = {k: v for k, v in state.items() if k in joints}
|
||||
|
||||
if normalize:
|
||||
state = self._normalize(state)
|
||||
|
||||
state = self._apply_ema(state)
|
||||
|
||||
return state
|
||||
|
||||
def _read_loop(self):
    """
    Continuously read frames from the serial port; meant to run in the background reader thread.

    Each accepted frame is parsed into a dict of integer joint readings, stored in `self._state`
    under `state_lock`, and announced by setting `new_state_event`. The loop runs until `stop_event`
    is set.
    """
    while not self.stop_event.is_set():
        try:
            raw_values = None
            with self.serial_lock:
                if self.serial.in_waiting > 0:
                    # NOTE(review): pyserial documents `flush()` as waiting until pending output is
                    # written; it does not discard buffered input. Confirm this call is intended.
                    self.serial.flush()
                    raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
            # Skip iterations without data and frames with an unexpected number of fields
            if raw_values is None or len(raw_values) != 21:  # 16 raw + 5 angle values
                continue

            # Pick the field used for each joint by its position in the frame
            joint_angles = {
                "shoulder_pitch": int(raw_values[19]),
                "shoulder_yaw": int(raw_values[18]),
                "shoulder_roll": int(raw_values[20]),
                "elbow_flex": int(raw_values[17]),
                "wrist_roll": int(raw_values[16]),
                "wrist_yaw": int(raw_values[1]),
                "wrist_pitch": int(raw_values[0]),
            }

            # Publish the frame and signal readers under the same lock
            with self.state_lock:
                self._state = joint_angles
                self.new_state_event.set()

        except Exception as e:
            # Frames that fail to decode or parse are logged at debug level and skipped
            logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def get_action(self) -> dict[str, float]:
    """Read the current joint positions and key each one as `<joint>.pos`."""
    action = {}
    for joint, pos in self._read().items():
        action[f"{joint}.pos"] = pos
    return action
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
    """Feedback is not implemented for this device; always raises `NotImplementedError`."""
    raise NotImplementedError()
|
||||
|
||||
def disconnect(self) -> None:
    """Stop the reader thread and close the serial port.

    Raises:
        DeviceNotConnectedError: the device is not currently connected.
    """
    if not self.is_connected:
        # The exception was previously built but never raised, so this guard had no effect.
        raise DeviceNotConnectedError(f"{self} is not connected.")

    self.stop_event.set()
    # Give the reader up to one second to leave its loop before closing the port
    self.thread.join(timeout=1)
    self.serial.close()
    logger.info(f"{self} disconnected.")
|
||||
338
src/lerobot/teleoperators/homunculus/homunculus_glove.py
Normal file
338
src/lerobot/teleoperators/homunculus/homunculus_glove.py
Normal file
@@ -0,0 +1,338 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from pprint import pformat
|
||||
from typing import Deque, Dict, Optional
|
||||
|
||||
import serial
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.motors.motors_bus import MotorNormMode
|
||||
from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_homunculus import HomunculusGloveConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Joints whose normalized value is mirrored on the left-hand glove: `HomunculusGlove.calibrate()`
# stores `drive_mode=1` for every joint listed here and `drive_mode=0` for all others.
LEFT_HAND_INVERSIONS = [
    "thumb_cmc",
    "index_dip",
    "middle_mcp_abduction",
    "middle_dip",
    "pinky_mcp_abduction",
    "pinky_dip",
]

# Same role for the right-hand glove; this list is selected when `config.side == "right"`.
# The commented-out entries are currently not inverted.
RIGHT_HAND_INVERSIONS = [
    "thumb_mcp",
    "thumb_cmc",
    "thumb_pip",
    "thumb_dip",
    "index_mcp_abduction",
    # "index_dip",
    "middle_mcp_abduction",
    # "middle_dip",
    "ring_mcp_abduction",
    "ring_mcp_flexion",
    # "ring_dip",
    "pinky_mcp_abduction",
]
|
||||
|
||||
|
||||
class HomunculusGlove(Teleoperator):
    """
    Homunculus Glove designed by NepYope & Hugging Face.

    A background thread reads space-separated integer frames from the serial port. `get_action()`
    smooths the latest frame with a per-joint EMA, normalizes it with the stored calibration and
    converts it with `homunculus_glove_to_hope_jr_hand`.
    """

    config_class = HomunculusGloveConfig
    name = "homunculus_glove"

    def __init__(self, config: HomunculusGloveConfig):
        super().__init__(config)
        self.config = config
        self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
        self.serial_lock = threading.Lock()

        # Joint name -> normalization mode. The order matters: `_read_loop` pairs the fields of
        # each serial frame with these keys positionally.
        self.joints = {
            "thumb_cmc": MotorNormMode.RANGE_0_100,
            "thumb_mcp": MotorNormMode.RANGE_0_100,
            "thumb_pip": MotorNormMode.RANGE_0_100,
            "thumb_dip": MotorNormMode.RANGE_0_100,
            "index_mcp_abduction": MotorNormMode.RANGE_M100_100,
            "index_mcp_flexion": MotorNormMode.RANGE_0_100,
            "index_dip": MotorNormMode.RANGE_0_100,
            "middle_mcp_abduction": MotorNormMode.RANGE_M100_100,
            "middle_mcp_flexion": MotorNormMode.RANGE_0_100,
            "middle_dip": MotorNormMode.RANGE_0_100,
            "ring_mcp_abduction": MotorNormMode.RANGE_M100_100,
            "ring_mcp_flexion": MotorNormMode.RANGE_0_100,
            "ring_dip": MotorNormMode.RANGE_0_100,
            "pinky_mcp_abduction": MotorNormMode.RANGE_M100_100,
            "pinky_mcp_flexion": MotorNormMode.RANGE_0_100,
            "pinky_dip": MotorNormMode.RANGE_0_100,
        }
        self.inverted_joints = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS

        n = 10
        # EMA parameters ---------------------------------------------------
        self.n: int = n
        self.alpha: float = 2 / (n + 1)
        # one deque *per joint* so we can inspect raw history if needed
        self._buffers: Dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
        # running EMA value per joint – lazily initialised on first read
        self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers)

        self._state: dict[str, float] | None = None
        self.new_state_event = threading.Event()
        self.stop_event = threading.Event()
        self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
        self.state_lock = threading.Lock()

    @property
    def action_features(self) -> dict:
        """One `<joint>.pos` key per glove joint, each typed as `float`."""
        return {f"{joint}.pos": float for joint in self.joints}

    @property
    def feedback_features(self) -> dict:
        """Empty: this device defines no feedback features."""
        return {}

    @property
    def is_connected(self) -> bool:
        """Whether the serial port is open and the reader thread is alive."""
        with self.serial_lock:
            return self.serial.is_open and self.thread.is_alive()

    def connect(self, calibrate: bool = True) -> None:
        """Open the serial port, start the reader thread and wait for a first state.

        Args:
            calibrate (bool, optional): When `True` (default), run `calibrate()` if no calibration
                file exists yet.

        Raises:
            DeviceAlreadyConnectedError: the port is open and the reader thread is alive.
            TimeoutError: the reader thread published no state within 2 seconds.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")

        if not self.serial.is_open:
            self.serial.open()

        # A `threading.Thread` can only be started once. After `disconnect()` the previous reader
        # has finished and `stop_event` is still set, so reset the events and build a fresh thread;
        # otherwise reconnecting would fail with `RuntimeError`.
        if self.thread.ident is not None and not self.thread.is_alive():
            self.stop_event.clear()
            self.new_state_event.clear()
            self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
        self.thread.start()

        # wait for the thread to ramp up & 1st state to be ready
        if not self.new_state_event.wait(timeout=2):
            raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")

        if not self.is_calibrated and calibrate:
            self.calibrate()

        logger.info(f"{self} connected.")

    @property
    def is_calibrated(self) -> bool:
        """Whether a calibration file exists at `self.calibration_fpath`."""
        return self.calibration_fpath.is_file()

    def calibrate(self) -> None:
        """Record each finger's range of motion in turn and save the resulting calibration.

        Joints listed in `self.inverted_joints` are stored with `drive_mode=1`, all others with
        `drive_mode=0`.
        """
        range_mins, range_maxes = {}, {}
        for finger in ["thumb", "index", "middle", "ring", "pinky"]:
            print(
                f"\nMove {finger} through its entire range of motion."
                "\nRecording positions. Press ENTER to stop..."
            )
            finger_joints = [joint for joint in self.joints if joint.startswith(finger)]
            finger_mins, finger_maxes = self._record_ranges_of_motion(finger_joints)
            range_mins.update(finger_mins)
            range_maxes.update(finger_maxes)

        self.calibration = {}
        for id_, joint in enumerate(self.joints):
            self.calibration[joint] = MotorCalibration(
                id=id_,
                drive_mode=1 if joint in self.inverted_joints else 0,
                homing_offset=0,
                range_min=range_mins[joint],
                range_max=range_maxes[joint],
            )

        self._save_calibration()
        print("Calibration saved to", self.calibration_fpath)

    # TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to a utility to reduce duplicated code.
    def _record_ranges_of_motion(
        self, joints: list[str] | None = None, display_values: bool = True
    ) -> tuple[dict[str, int], dict[str, int]]:
        """Interactively record the min/max encoder values of each joint.

        Move the joints while the method streams live positions. Press :kbd:`Enter` to finish.

        Args:
            joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`).
            display_values (bool, optional): When `True` (default) a live table is printed to the console.

        Raises:
            TypeError: `joints` is not `None` or a list.
            ValueError: any joint's recorded min and max are the same.

        Returns:
            tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values
                observed for each joint.
        """
        if joints is None:
            joints = list(self.joints)
        elif not isinstance(joints, list):
            raise TypeError(joints)

        display_len = max(len(key) for key in joints)

        start_positions = self._read(joints, normalize=False)
        mins = start_positions.copy()
        maxes = start_positions.copy()

        user_pressed_enter = False
        while not user_pressed_enter:
            positions = self._read(joints, normalize=False)
            mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()}
            maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()}

            if display_values:
                print("\n-------------------------------------------")
                print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
                for joint in joints:
                    print(
                        f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}"
                    )

            if enter_pressed():
                user_pressed_enter = True

            if display_values and not user_pressed_enter:
                # Move cursor up to overwrite the previous output
                move_cursor_up(len(joints) + 3)

        same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]]
        if same_min_max:
            raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}")

        return mins, maxes

    def configure(self) -> None:
        """No configuration step is implemented for this device."""
        pass

    # TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to a utility to reduce duplicated code.
    def _normalize(self, values: dict[str, int]) -> dict[str, float]:
        """Rescale readings to each joint's normalized range using the stored calibration.

        Values are clamped to the calibrated `[range_min, range_max]`, then mapped to [-100, 100]
        (`RANGE_M100_100`) or [0, 100] (`RANGE_0_100`). A truthy `drive_mode` mirrors the result.

        Raises:
            RuntimeError: no calibration is registered.
        """
        if not self.calibration:
            raise RuntimeError(f"{self} has no calibration registered.")

        normalized_values = {}
        for joint, val in values.items():
            min_ = self.calibration[joint].range_min
            max_ = self.calibration[joint].range_max
            drive_mode = self.calibration[joint].drive_mode
            bounded_val = min(max_, max(min_, val))

            if self.joints[joint] is MotorNormMode.RANGE_M100_100:
                norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
                normalized_values[joint] = -norm if drive_mode else norm
            elif self.joints[joint] is MotorNormMode.RANGE_0_100:
                norm = ((bounded_val - min_) / (max_ - min_)) * 100
                normalized_values[joint] = 100 - norm if drive_mode else norm

        return normalized_values

    def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, int]:
        """Update buffers & running EMA values; return smoothed dict as integers."""
        smoothed: Dict[str, int] = {}
        for joint, value in raw.items():
            # maintain raw history
            self._buffers[joint].append(value)

            # initialise on first run
            if self._ema[joint] is None:
                self._ema[joint] = float(value)
            else:
                self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint]

            # Convert back to int for compatibility with normalization
            smoothed[joint] = int(round(self._ema[joint]))
        return smoothed

    def _read(
        self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
    ) -> dict[str, int | float]:
        """Return the latest state published by the reader thread, smoothed and optionally normalized.

        Args:
            joints (list[str] | None, optional): Restrict the result to these joints. `None`
                (default) returns every joint.
            normalize (bool, optional): When `True` (default), pass the smoothed values through
                `_normalize()`.
            timeout (float, optional): Seconds to wait for a new state. Defaults to 1.

        Raises:
            TimeoutError: no new state was published within `timeout` seconds.
            RuntimeError: the event was set but no state is stored.
        """
        if not self.new_state_event.wait(timeout=timeout):
            raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")

        with self.state_lock:
            state = self._state
            # Clear while holding the lock: the reader sets the event under the same lock, so a
            # state published right after this read cannot have its notification wiped.
            self.new_state_event.clear()

        if state is None:
            raise RuntimeError(f"{self} Internal error: Event set but no state available.")

        if joints is not None:
            state = {k: v for k, v in state.items() if k in joints}

        # Apply EMA smoothing to raw values first
        state = self._apply_ema(state)

        # Then normalize if requested
        if normalize:
            state = self._normalize(state)

        return state

    def _read_loop(self):
        """
        Continuously read frames from the serial port; meant to run in the background reader thread.

        Each accepted frame is parsed into a dict of integer joint readings, stored in `self._state`
        under `state_lock`, and announced by setting `new_state_event`. The loop runs until
        `stop_event` is set.
        """
        while not self.stop_event.is_set():
            try:
                positions = None
                with self.serial_lock:
                    if self.serial.in_waiting > 0:
                        # NOTE(review): pyserial documents `flush()` as waiting until pending output
                        # is written; it does not discard buffered input. Confirm this is intended.
                        self.serial.flush()
                        positions = self.serial.readline().decode("utf-8").strip().split(" ")
                # Skip iterations without data and frames with an unexpected number of fields
                if positions is None or len(positions) != len(self.joints):
                    continue

                joint_positions = {joint: int(pos) for joint, pos in zip(self.joints, positions, strict=True)}

                with self.state_lock:
                    self._state = joint_positions
                    self.new_state_event.set()

            except Exception as e:
                # Frames that fail to decode or parse are logged at debug level and skipped
                logger.debug(f"Error reading frame in background thread for {self}: {e}")

    def get_action(self) -> dict[str, float]:
        """Read the glove joints and convert them with `homunculus_glove_to_hope_jr_hand`."""
        joint_positions = self._read()
        return homunculus_glove_to_hope_jr_hand(
            {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
        )

    def send_feedback(self, feedback: dict[str, float]) -> None:
        """Feedback is not implemented for this device."""
        raise NotImplementedError

    def disconnect(self) -> None:
        """Stop the reader thread and close the serial port.

        Raises:
            DeviceNotConnectedError: the device is not currently connected.
        """
        if not self.is_connected:
            # The exception was previously built but never raised, so this guard had no effect.
            raise DeviceNotConnectedError(f"{self} is not connected.")

        self.stop_event.set()
        self.thread.join(timeout=1)
        self.serial.close()
        logger.info(f"{self} disconnected.")
|
||||
63
src/lerobot/teleoperators/homunculus/joints_translation.py
Normal file
63
src/lerobot/teleoperators/homunculus/joints_translation.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Per-finger mixing weights used by `homunculus_glove_to_hope_jr_hand`: in the flexor formulas
# `splay` multiplies the abduction reading and `1 - splay` multiplies the flexion reading.
INDEX_SPLAY = 0.3
MIDDLE_SPLAY = 0.3
RING_SPLAY = 0.3
PINKY_SPLAY = 0.5
|
||||
|
||||
|
||||
def get_ulnar_flexion(flexion: float, abduction: float, splay: float):
    """Blend flexion and abduction into the ulnar flexor value.

    Abduction enters with a negative sign and weight `splay`; flexion enters with weight `1 - splay`.
    """
    abduction_term = -abduction * splay
    flexion_term = flexion * (1 - splay)
    return abduction_term + flexion_term
|
||||
|
||||
|
||||
def get_radial_flexion(flexion: float, abduction: float, splay: float):
    """Blend flexion and abduction into the radial flexor value.

    Abduction enters with weight `splay`; flexion enters with weight `1 - splay`.
    """
    abduction_term = abduction * splay
    flexion_term = flexion * (1 - splay)
    return abduction_term + flexion_term
|
||||
|
||||
|
||||
def homunculus_glove_to_hope_jr_hand(glove_action: dict[str, float]) -> dict[str, float]:
    """Translate a glove action dict into the hand's action keys.

    The four thumb joints are copied unchanged. For each other finger, the MCP flexion and abduction
    readings are mixed into a radial and an ulnar flexor value using that finger's splay weight, and
    the finger's `*_dip.pos` reading is passed through as `*_pip_dip.pos`.
    """
    hand_action = {}
    for thumb_key in ("thumb_cmc.pos", "thumb_mcp.pos", "thumb_pip.pos", "thumb_dip.pos"):
        hand_action[thumb_key] = glove_action[thumb_key]

    # (flexion key, abduction key, dip key, splay, radial key, ulnar key, pip/dip key) per finger
    finger_table = (
        (
            "index_mcp_flexion.pos",
            "index_mcp_abduction.pos",
            "index_dip.pos",
            INDEX_SPLAY,
            "index_radial_flexor.pos",
            "index_ulnar_flexor.pos",
            "index_pip_dip.pos",
        ),
        (
            "middle_mcp_flexion.pos",
            "middle_mcp_abduction.pos",
            "middle_dip.pos",
            MIDDLE_SPLAY,
            "middle_radial_flexor.pos",
            "middle_ulnar_flexor.pos",
            "middle_pip_dip.pos",
        ),
        (
            "ring_mcp_flexion.pos",
            "ring_mcp_abduction.pos",
            "ring_dip.pos",
            RING_SPLAY,
            "ring_radial_flexor.pos",
            "ring_ulnar_flexor.pos",
            "ring_pip_dip.pos",
        ),
        (
            "pinky_mcp_flexion.pos",
            "pinky_mcp_abduction.pos",
            "pinky_dip.pos",
            PINKY_SPLAY,
            "pinky_radial_flexor.pos",
            "pinky_ulnar_flexor.pos",
            "pinky_pip_dip.pos",
        ),
    )
    for flexion_key, abduction_key, dip_key, splay, radial_key, ulnar_key, pip_dip_key in finger_table:
        flexion = glove_action[flexion_key]
        abduction = glove_action[abduction_key]
        hand_action[radial_key] = get_radial_flexion(flexion, abduction, splay)
        hand_action[ulnar_key] = get_ulnar_flexion(flexion, abduction, splay)
        hand_action[pip_dip_key] = glove_action[dip_key]

    return hand_action
|
||||
@@ -33,6 +33,12 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .so101_leader import SO101Leader
|
||||
|
||||
return SO101Leader(config)
|
||||
elif config.type == "so101_follower_t":
|
||||
# For bilateral teleoperation, SO101FollowerT is used as a robot, not a teleoperator
|
||||
# This should be handled in the record.py file instead
|
||||
raise ValueError(
|
||||
"so101_follower_t should be created as a robot instance for bilateral teleoperation, not as a teleoperator"
|
||||
)
|
||||
elif config.type == "stretch3":
|
||||
from .stretch3_gamepad import Stretch3GamePad
|
||||
|
||||
@@ -53,5 +59,13 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .keyboard.teleop_keyboard import KeyboardEndEffectorTeleop
|
||||
|
||||
return KeyboardEndEffectorTeleop(config)
|
||||
elif config.type == "homunculus_glove":
|
||||
from .homunculus import HomunculusGlove
|
||||
|
||||
return HomunculusGlove(config)
|
||||
elif config.type == "homunculus_arm":
|
||||
from .homunculus import HomunculusArm
|
||||
|
||||
return HomunculusArm(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
59
src/lerobot/transport/async_inference.proto
Normal file
59
src/lerobot/transport/async_inference.proto
Normal file
@@ -0,0 +1,59 @@
|
||||
// fmt: off
|
||||
// flake8: noqa
|
||||
// !/usr/bin/env python
|
||||
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
package async_inference;
|
||||
|
||||
// AsyncInference: from Robot perspective
// Robot sends observations to & executes actions received from a remote Policy server
service AsyncInference {
  // Robot -> Policy: client-streamed observations shared with a remote inference server
  rpc SendObservations(stream Observation) returns (Empty);
  // Policy -> Robot: actions predicted for given observations
  rpc GetActions(Empty) returns (Actions);
  // Robot -> Policy: sends the PolicySetup payload used to init the Policy
  rpc SendPolicyInstructions(PolicySetup) returns (Empty);
  // NOTE(review): presumably a readiness handshake — semantics are not defined in this file, confirm
  rpc Ready(Empty) returns (Empty);
  // NOTE(review): presumably asks the server to stop — semantics are not defined in this file, confirm
  rpc Stop(Empty) returns (Empty);
}
|
||||
|
||||
// State tag carried by each streamed Observation message.
// NOTE(review): presumably BEGIN / MIDDLE / END mark the first, intermediate and last chunk of one
// payload — confirm against the sender.
enum TransferState {
  TRANSFER_UNKNOWN = 0;
  TRANSFER_BEGIN = 1;
  TRANSFER_MIDDLE = 2;
  TRANSFER_END = 3;
}
|
||||
|
||||
// Messages
|
||||
message Observation {
  // sent by Robot, to remote Policy
  TransferState transfer_state = 1; // Observations exceeding 4MB in size can be streamed
  bytes data = 2; // payload bytes; the encoding is not defined in this file
}
|
||||
|
||||
message Actions {
  // sent by remote Policy, to Robot
  bytes data = 1; // payload bytes; the encoding is not defined in this file
}
|
||||
|
||||
message PolicySetup {
  // sent by Robot to remote server, to init Policy
  bytes data = 1; // payload bytes; the encoding is not defined in this file
}
|
||||
|
||||
// Field-less message used as the request or response of RPCs that carry no data
message Empty {}
|
||||
45
src/lerobot/transport/async_inference_pb2.py
Normal file
45
src/lerobot/transport/async_inference_pb2.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: async_inference.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'async_inference.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=190
|
||||
_globals['_TRANSFERSTATE']._serialized_end=286
|
||||
_globals['_OBSERVATION']._serialized_start=42
|
||||
_globals['_OBSERVATION']._serialized_end=125
|
||||
_globals['_ACTIONS']._serialized_start=127
|
||||
_globals['_ACTIONS']._serialized_end=150
|
||||
_globals['_POLICYSETUP']._serialized_start=152
|
||||
_globals['_POLICYSETUP']._serialized_end=179
|
||||
_globals['_EMPTY']._serialized_start=181
|
||||
_globals['_EMPTY']._serialized_end=188
|
||||
_globals['_ASYNCINFERENCE']._serialized_start=289
|
||||
_globals['_ASYNCINFERENCE']._serialized_end=638
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
277
src/lerobot/transport/async_inference_pb2_grpc.py
Normal file
277
src/lerobot/transport/async_inference_pb2_grpc.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from lerobot.transport import async_inference_pb2 as async__inference__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class AsyncInferenceStub:
    """Client-side stub for the ``async_inference.AsyncInference`` service.

    AsyncInference, from the robot's perspective: the robot sends observations to,
    and executes actions received from, a remote policy server.

    Each attribute created in ``__init__`` is a callable bound to one RPC path of the
    service on the supplied channel.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # Client-streaming RPC: a stream of Observation messages in, a single Empty out.
        self.SendObservations = channel.stream_unary(
                '/async_inference.AsyncInference/SendObservations',
                request_serializer=async__inference__pb2.Observation.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        # Unary RPC: Empty in, Actions out.
        self.GetActions = channel.unary_unary(
                '/async_inference.AsyncInference/GetActions',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Actions.FromString,
                _registered_method=True)
        # Unary RPC: PolicySetup in, Empty out.
        self.SendPolicyInstructions = channel.unary_unary(
                '/async_inference.AsyncInference/SendPolicyInstructions',
                request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        # Unary RPC: Empty in, Empty out.
        self.Ready = channel.unary_unary(
                '/async_inference.AsyncInference/Ready',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        # Unary RPC: Empty in, Empty out.
        self.Stop = channel.unary_unary(
                '/async_inference.AsyncInference/Stop',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceServicer:
    """Server-side base class for the ``async_inference.AsyncInference`` service.

    AsyncInference, from the robot's perspective: the robot sends observations to,
    and executes actions received from, a remote policy server.

    Every method here is a placeholder that marks the call as UNIMPLEMENTED on the
    gRPC context and raises ``NotImplementedError``; a concrete server is expected to
    subclass this and override the RPCs it supports.
    """

    def SendObservations(self, request_iterator, context):
        """Robot -> Policy to share observations with a remote inference server.

        Client-streaming RPC: receives an iterator of requests.

        NOTE(review): the generated docstring also carried the sentence
        "Policy -> Robot to share actions predicted for given observations",
        which looks like it documents ``GetActions`` — confirm in the .proto file.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetActions(self, request, context):
        """Unary RPC with no documentation comment in the .proto file.

        NOTE(review): presumably the "Policy -> Robot to share actions predicted for
        given observations" call — confirm against the .proto file.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendPolicyInstructions(self, request, context):
        """Unary RPC with no documentation comment in the .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Ready(self, request, context):
        """Unary RPC with no documentation comment in the .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Stop(self, request, context):
        """Unary RPC with no documentation comment in the .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsyncInferenceServicer_to_server(servicer, server):
    """Register ``servicer``'s RPC methods on ``server`` under 'async_inference.AsyncInference'.

    Args:
        servicer: Object providing SendObservations, GetActions,
            SendPolicyInstructions, Ready and Stop methods.
        server: Server object exposing ``add_generic_rpc_handlers`` and
            ``add_registered_method_handlers``.
    """
    # Map each RPC name to a handler pairing the servicer method with the
    # request deserializer / response serializer of its message types.
    rpc_method_handlers = {
            'SendObservations': grpc.stream_unary_rpc_method_handler(
                    servicer.SendObservations,
                    request_deserializer=async__inference__pb2.Observation.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'GetActions': grpc.unary_unary_rpc_method_handler(
                    servicer.GetActions,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Actions.SerializeToString,
            ),
            'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPolicyInstructions,
                    request_deserializer=async__inference__pb2.PolicySetup.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Ready': grpc.unary_unary_rpc_method_handler(
                    servicer.Ready,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Stop': grpc.unary_unary_rpc_method_handler(
                    servicer.Stop,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'async_inference.AsyncInference', rpc_method_handlers)
    # The same handler table is registered twice: once wrapped in a generic handler,
    # once through the registered-method API.
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
|
||||
|
||||
|
||||
 # This class is part of an EXPERIMENTAL API.
class AsyncInference:
    """Static one-shot callers for the ``async_inference.AsyncInference`` service.

    AsyncInference, from the robot's perspective: the robot sends observations to,
    and executes actions received from, a remote policy server.

    Each static method forwards its arguments to ``grpc.experimental.stream_unary`` or
    ``grpc.experimental.unary_unary`` for one RPC path, addressed by ``target``.
    Note that ``insecure`` is forwarded positionally *before* ``call_credentials``,
    which differs from the order in which the two appear in the method signatures.
    """

    @staticmethod
    def SendObservations(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call SendObservations on ``target``: stream of Observation in, Empty out."""
        return grpc.experimental.stream_unary(
            request_iterator,
            target,
            '/async_inference.AsyncInference/SendObservations',
            async__inference__pb2.Observation.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetActions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call GetActions on ``target``: Empty in, Actions out."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/GetActions',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Actions.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendPolicyInstructions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call SendPolicyInstructions on ``target``: PolicySetup in, Empty out."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/SendPolicyInstructions',
            async__inference__pb2.PolicySetup.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Ready(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call Ready on ``target``: Empty in, Empty out."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Ready',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Stop(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call Stop on ``target``: Empty in, Empty out."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Stop',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
|
||||
@@ -11,11 +11,11 @@
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// limitations under the License.
|
||||
|
||||
// To generate the classes for the transport part (services_pb2.py and services_pb2_grpc.py), use the following command:
|
||||
//
|
||||
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. src/lerobot/transport/services.proto
|
||||
// python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto
|
||||
//
|
||||
// The command should be launched from the root of the project.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: src/lerobot/transport/services.proto
|
||||
# source: lerobot/transport/services.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
@@ -14,7 +14,7 @@ _runtime_version.ValidateProtobufRuntimeVersion(
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'src/lerobot/transport/services.proto'
|
||||
'lerobot/transport/services.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
@@ -23,23 +23,23 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.lerobot.transport.services_pb2', _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=302
|
||||
_globals['_TRANSFERSTATE']._serialized_end=398
|
||||
_globals['_TRANSITION']._serialized_start=51
|
||||
_globals['_TRANSITION']._serialized_end=127
|
||||
_globals['_PARAMETERS']._serialized_start=129
|
||||
_globals['_PARAMETERS']._serialized_end=205
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=207
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=291
|
||||
_globals['_EMPTY']._serialized_start=293
|
||||
_globals['_EMPTY']._serialized_end=300
|
||||
_globals['_LEARNERSERVICE']._serialized_start=401
|
||||
_globals['_LEARNERSERVICE']._serialized_end=658
|
||||
_globals['_TRANSFERSTATE']._serialized_start=298
|
||||
_globals['_TRANSFERSTATE']._serialized_end=394
|
||||
_globals['_TRANSITION']._serialized_start=47
|
||||
_globals['_TRANSITION']._serialized_end=123
|
||||
_globals['_PARAMETERS']._serialized_start=125
|
||||
_globals['_PARAMETERS']._serialized_end=201
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=203
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=287
|
||||
_globals['_EMPTY']._serialized_start=289
|
||||
_globals['_EMPTY']._serialized_end=296
|
||||
_globals['_LEARNERSERVICE']._serialized_start=397
|
||||
_globals['_LEARNERSERVICE']._serialized_end=654
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from src.lerobot.transport import services_pb2 as src_dot_lerobot_dot_transport_dot_services__pb2
|
||||
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
@@ -18,7 +18,7 @@ except ImportError:
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in src/lerobot/transport/services_pb2_grpc.py depends on'
|
||||
+ f' but the generated code in lerobot/transport/services_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
@@ -38,23 +38,23 @@ class LearnerServiceStub:
|
||||
"""
|
||||
self.StreamParameters = channel.unary_stream(
|
||||
'/transport.LearnerService/StreamParameters',
|
||||
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/transport.LearnerService/SendTransitions',
|
||||
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/transport.LearnerService/SendInteractions',
|
||||
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/transport.LearnerService/Ready',
|
||||
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
@@ -93,23 +93,23 @@ def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamParameters,
|
||||
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.FromString,
|
||||
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Transition.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
@@ -139,8 +139,8 @@ class LearnerService:
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/StreamParameters',
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -166,8 +166,8 @@ class LearnerService:
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.LearnerService/SendTransitions',
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -193,8 +193,8 @@ class LearnerService:
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.LearnerService/SendInteractions',
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -220,8 +220,8 @@ class LearnerService:
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/Ready',
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
|
||||
@@ -111,35 +111,46 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
def custom_format(record):
|
||||
def init_logging(
|
||||
log_file: Path | None = None,
|
||||
display_pid: bool = False,
|
||||
console_level: str = "INFO",
|
||||
file_level: str = "DEBUG",
|
||||
):
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
formatter = logging.Formatter()
|
||||
formatter.format = custom_format
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
|
||||
|
||||
# Remove unused default handlers
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
# Write logs to console
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
console_handler.setLevel(console_level.upper())
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# Additionally write logs to file
|
||||
if log_file is not None:
|
||||
# Additionally write logs to file
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
file_handler.setLevel(file_level.upper())
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
|
||||
@@ -28,19 +28,35 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
def log_rerun_data(observation: dict[str, Any], action: dict[str, Any]) -> None:
    """Log observation and action values to rerun under ``observation.*`` / ``action.*``.

    Supported value kinds:
      * ``float`` / ``int``: logged as a single scalar.
      * ``dict``: each numeric entry is logged as ``<name>.<key>``.
      * ``np.ndarray``: for observations, 1-D arrays are logged element-wise as
        ``<name>_<i>`` and any other rank is logged as a static image; for actions,
        the array is always logged element-wise.
      * ``list`` (actions only): each numeric item is logged as ``<name>_<i>``.

    Values of any other type are ignored.

    Args:
        observation: Mapping from observation name to value.
        action: Mapping from action name to value.
    """
    for obs, val in observation.items():
        # Accept ints as well as floats, consistent with the nested-value checks below.
        if isinstance(val, (float, int)):
            rr.log(f"observation.{obs}", rr.Scalar(float(val)))
        elif isinstance(val, dict):
            # Handle dictionary of joint values
            for joint_name, joint_val in val.items():
                if isinstance(joint_val, (float, int)):
                    rr.log(f"observation.{obs}.{joint_name}", rr.Scalar(float(joint_val)))
        elif isinstance(val, np.ndarray):
            if val.ndim == 1:
                for i, v in enumerate(val):
                    rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
            else:
                rr.log(f"observation.{obs}", rr.Image(val), static=True)

    for act, val in action.items():
        # Accept ints as well as floats, consistent with the nested-value checks below.
        if isinstance(val, (float, int)):
            rr.log(f"action.{act}", rr.Scalar(float(val)))
        elif isinstance(val, dict):
            # Handle dictionary of joint values
            for joint_name, joint_val in val.items():
                if isinstance(joint_val, (float, int)):
                    rr.log(f"action.{act}.{joint_name}", rr.Scalar(float(joint_val)))
        elif isinstance(val, np.ndarray):
            for i, v in enumerate(val):
                rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
        elif isinstance(val, list):
            # Handle list of values
            for i, v in enumerate(val):
                if isinstance(v, (float, int)):
                    rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
|
||||
|
||||
177
tests/async_inference/test_e2e.py
Normal file
177
tests/async_inference/test_e2e.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end test of the asynchronous inference stack (client ↔ server).
|
||||
|
||||
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
|
||||
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
|
||||
is to exercise the full communication loop:
|
||||
|
||||
1. Client sends policy specification → Server
|
||||
2. Client streams observations → Server
|
||||
3. Server streams action chunks → Client
|
||||
4. Client executes received actions
|
||||
|
||||
The test succeeds if at least one action is executed and the server records at
|
||||
least one predicted timestep - demonstrating that the gRPC round-trip works
|
||||
end-to-end using real (but lightweight) protocol messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from concurrent import futures
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if grpc is not available
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# End-to-end test
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_async_inference_e2e(monkeypatch):
    """Tests the full asynchronous inference pipeline.

    A gRPC server wrapping a ``PolicyServer`` (with a stubbed policy) is started on
    the configured host/port, a ``RobotClient`` built around ``MockRobotConfig``
    connects to it, and the two exchange messages for up to five seconds. The test
    passes when the client aggregated at least one action chunk and the server
    recorded at least one predicted timestep. Client threads and the server are
    always shut down, even when an assertion fails.
    """
    # Import grpc-dependent modules inside the test function
    import grpc

    from lerobot.robots.utils import make_robot_from_config
    from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
    from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
    from lerobot.scripts.server.policy_server import PolicyServer
    from lerobot.scripts.server.robot_client import RobotClient
    from lerobot.transport import (
        async_inference_pb2,  # type: ignore
        async_inference_pb2_grpc,  # type: ignore
    )
    from tests.mocks.mock_robot import MockRobotConfig

    # Create a stub policy similar to test_policy_server.py
    class MockPolicy:
        """A minimal mock for an actual policy, returning zeros."""

        class _Config:
            robot_type = "dummy_robot"

            @property
            def image_features(self):
                """Empty image features since this test doesn't use images."""
                return {}

        def __init__(self):
            self.config = self._Config()

        def to(self, *args, **kwargs):
            return self

        def model(self, batch):
            # Return a chunk of 20 dummy actions.
            batch_size = len(batch["robot_type"])
            return torch.zeros(batch_size, 20, 6)

    # ------------------------------------------------------------------
    # 1. Create PolicyServer instance with mock policy
    # ------------------------------------------------------------------
    policy_server_config = PolicyServerConfig(host="localhost", port=9999)
    policy_server = PolicyServer(policy_server_config)
    # Replace the real policy with our fast, deterministic stub.
    policy_server.policy = MockPolicy()
    policy_server.actions_per_chunk = 20
    policy_server.device = "cpu"

    # Set up robot config and features
    robot_config = MockRobotConfig()
    mock_robot = make_robot_from_config(robot_config)

    lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
    policy_server.lerobot_features = lerobot_features

    # Force server to produce deterministic action chunks in test mode
    policy_server.policy_type = "act"

    def _fake_get_action_chunk(_self, _obs, _type="test"):
        action_dim = 6
        batch_size = 1
        actions_per_chunk = policy_server.actions_per_chunk

        return torch.zeros(batch_size, actions_per_chunk, action_dim)

    monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)

    # Bypass potentially heavy model loading inside SendPolicyInstructions
    def _fake_send_policy_instructions(self, request, context):  # noqa: N802
        return async_inference_pb2.Empty()

    monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)

    # Build gRPC server running a PolicyServer
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)

    # Use the host/port specified in the policy server's config
    server_address = f"{policy_server.config.host}:{policy_server.config.port}"
    server.add_insecure_port(server_address)
    server.start()

    try:
        # ------------------------------------------------------------------
        # 2. Create a RobotClient around the MockRobot
        # ------------------------------------------------------------------
        client_config = RobotClientConfig(
            server_address=server_address,
            robot=robot_config,
            chunk_size_threshold=0.0,
            policy_type="test",
            pretrained_name_or_path="test",
            actions_per_chunk=20,
            verify_robot_cameras=False,
        )

        client = RobotClient(client_config)
        assert client.start(), "Client failed initial handshake with the server"

        # Track action chunks received without modifying RobotClient
        action_chunks_received = {"count": 0}
        original_aggregate = client._aggregate_action_queues

        def counting_aggregate(*args, **kwargs):
            action_chunks_received["count"] += 1
            return original_aggregate(*args, **kwargs)

        monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)

        # Start client threads
        action_thread = threading.Thread(target=client.receive_actions, daemon=True)
        # `args` must be a tuple. The previous `args=({"task": ""})` was a bare dict,
        # which Thread unpacked into its keys and so passed the string "task".
        # Pass the intended empty task explicitly instead.
        control_thread = threading.Thread(target=client.control_loop, args=("",), daemon=True)
        action_thread.start()
        control_thread.start()

        try:
            # ------------------------------------------------------------------
            # 3. System exchanges a few messages
            # ------------------------------------------------------------------
            # Wait for 5 seconds
            server.wait_for_termination(timeout=5)

            assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
            assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
        finally:
            # ------------------------------------------------------------------
            # 4. Stop the system (also on assertion failure)
            # ------------------------------------------------------------------
            client.stop()
            # Bounded joins so a stuck worker cannot hang the test run; the threads
            # are daemons and will not block interpreter exit.
            action_thread.join(timeout=5)
            control_thread.join(timeout=5)
            policy_server.stop()
    finally:
        # Always release the listening port, even if the handshake or assertions failed.
        server.stop(grace=None)
|
||||
459
tests/async_inference/test_helpers.py
Normal file
459
tests/async_inference/test_helpers.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.scripts.server.helpers import (
|
||||
FPSTracker,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
observations_similar,
|
||||
prepare_image,
|
||||
prepare_raw_observation,
|
||||
raw_observation_to_observation,
|
||||
resize_robot_observation_image,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FPSTracker
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fps_tracker_first_observation():
    """A tracker fed a single timestamp stores it and reports an average of 0 FPS."""
    first_ts = 1000.0
    fps_tracker = FPSTracker(target_fps=30.0)

    result = fps_tracker.calculate_fps_metrics(first_ts)

    # Tracker state after the very first sample.
    assert fps_tracker.first_timestamp == first_ts
    assert fps_tracker.total_obs_count == 1

    # Reported metrics: no interval measured yet, target value echoed back.
    assert result["avg_fps"] == 0.0
    assert result["target_fps"] == 30.0
|
||||
|
||||
|
||||
def test_fps_tracker_single_interval():
    """Two samples spaced one second apart yield an average of 1 FPS."""
    fps_tracker = FPSTracker(target_fps=30.0)

    # The sample at t=0 has no predecessor, so the average is still zero.
    assert fps_tracker.calculate_fps_metrics(0.0)["avg_fps"] == 0.0

    # The sample at t=1 closes one interval: (2 - 1) observations / 1.0 s = 1 FPS.
    second = fps_tracker.calculate_fps_metrics(1.0)
    assert math.isclose(second["avg_fps"], 1.0, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_multiple_intervals():
    """Five evenly spaced samples across two seconds average out to 2 FPS."""
    fps_tracker = FPSTracker(target_fps=30.0)
    first_ts, *middle_ts, last_ts = [0.0, 0.5, 1.0, 1.5, 2.0]

    # Only the first and the last sample are checked; the ones in between
    # are fed to the tracker without inspecting the returned metrics.
    assert fps_tracker.calculate_fps_metrics(first_ts)["avg_fps"] == 0.0

    for ts in middle_ts:
        fps_tracker.calculate_fps_metrics(ts)

    # After 5 observations over 2 seconds: (5 - 1) / 2 = 2 FPS.
    final = fps_tracker.calculate_fps_metrics(last_ts)
    assert math.isclose(final["avg_fps"], 2.0, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_irregular_intervals():
    """Uneven spacing between samples still gives the expected average FPS."""
    fps_tracker = FPSTracker(target_fps=30.0)

    # Gaps between consecutive samples: 0.1 s, 0.4 s, 1.5 s and 1.0 s.
    latest = None
    for ts in (0.0, 0.1, 0.5, 2.0, 3.0):
        latest = fps_tracker.calculate_fps_metrics(ts)

    # 5 observations spanning 3 seconds: (5 - 1) / 3 = 1.333... FPS.
    assert math.isclose(latest["avg_fps"], 4.0 / 3.0, rel_tol=1e-6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# TimedData helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_timed_action_getters():
    """The getters of a TimedAction hand back the values it was built with."""
    created_at = time.time()
    payload = torch.arange(10)

    timed = TimedAction(timestamp=created_at, action=payload, timestep=0)

    assert math.isclose(timed.get_timestamp(), created_at, rel_tol=0, abs_tol=1e-6)
    torch.testing.assert_close(timed.get_action(), payload)
    assert timed.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_observation_getters():
    """The getters of a TimedObservation hand back the values it was built with."""
    created_at = time.time()
    payload = {"observation.state": torch.ones(6)}

    timed = TimedObservation(timestamp=created_at, observation=payload, timestep=0)

    assert math.isclose(timed.get_timestamp(), created_at, rel_tol=0, abs_tol=1e-6)
    # The very same dict object is expected back, not a copy.
    assert timed.get_observation() is payload
    assert timed.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_data_deserialization_data_getters():
    """A ``pickle`` round-trip leaves TimedAction / TimedObservation content intact.

    Each object is dumped to bytes and loaded back; every getter on the
    restored copy must agree with the value the original was built from.
    Only data created inside this test is ever unpickled.
    """
    created_at = time.time()

    # --- TimedAction: build, serialize to bytes, deserialize ---------------
    action_tensor = torch.randn(6)
    action_before = TimedAction(timestamp=created_at, action=action_tensor, timestep=13)
    action_after = pickle.loads(pickle.dumps(action_before))  # nosec B301

    assert math.isclose(action_after.get_timestamp(), created_at, rel_tol=0, abs_tol=1e-6)
    assert action_after.get_timestep() == 13
    torch.testing.assert_close(action_after.get_action(), action_tensor)

    # --- TimedObservation: same round-trip, including the must_go flag -----
    state_dict = {"observation.state": torch.arange(4).float()}
    obs_before = TimedObservation(timestamp=created_at, observation=state_dict, timestep=7, must_go=True)
    obs_after = pickle.loads(pickle.dumps(obs_before))  # nosec B301

    assert math.isclose(obs_after.get_timestamp(), created_at, rel_tol=0, abs_tol=1e-6)
    assert obs_after.get_timestep() == 7
    assert obs_after.must_go is True

    restored_dict = obs_after.get_observation()
    assert restored_dict.keys() == state_dict.keys()
    torch.testing.assert_close(restored_dict["observation.state"], state_dict["observation.state"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# observations_similar()
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor) -> TimedObservation:
    """Wrap up to four state entries in a TimedObservation keyed by joint name.

    Args:
        state: 1-D tensor; entry ``i`` becomes the value of the i-th joint.
            Joints for which ``state`` has no entry are filled with ``0.0``.

    Returns:
        A TimedObservation stamped with the current wall-clock time and ``timestep=0``.
    """
    joint_names = ("shoulder", "elbow", "wrist", "gripper")
    n_values = len(state)
    created_at = time.time()

    joints = {
        name: (state[idx].item() if n_values > idx else 0.0)
        for idx, name in enumerate(joint_names)
    }
    return TimedObservation(timestamp=created_at, observation=joints, timestep=0)
|
||||
|
||||
|
||||
def test_observations_similar_true():
    """`observations_similar` accepts a nearby state and rejects a distant one at ``atol=2.0``."""
    # Minimal feature mapping: a single 4-joint state vector.
    state_features = {
        "observation.state": {
            "dtype": "float32",
            "shape": [4],
            "names": ["shoulder", "elbow", "wrist", "gripper"],
        }
    }
    reference = _make_obs(torch.zeros(4))

    # Every joint off by 0.5 -> expected to count as similar.
    nearby = _make_obs(0.5 * torch.ones(4))
    assert observations_similar(reference, nearby, state_features, atol=2.0)

    # Every joint off by 2.0 -> expected to count as different.
    distant = _make_obs(2.0 * torch.ones(4))
    assert not observations_similar(reference, distant, state_features, atol=2.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# raw_observation_to_observation and helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_mock_robot_observation():
|
||||
"""Create a mock robot observation with motor positions and camera images."""
|
||||
return {
|
||||
"shoulder": 1.0,
|
||||
"elbow": 2.0,
|
||||
"wrist": 3.0,
|
||||
"gripper": 0.5,
|
||||
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_lerobot_features():
|
||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
||||
return {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.phone": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_policy_image_features():
    """Build fake policy image features, one per camera, each with its own resolution."""
    # (C, H, W) shapes; both differ from the 480x640 frames the mock robot produces.
    policy_shapes = (
        ("observation.images.laptop", (3, 224, 224)),
        ("observation.images.phone", (3, 160, 160)),
    )
    return {
        key: PolicyFeature(type=FeatureType.VISUAL, shape=shape)
        for key, shape in policy_shapes
    }
|
||||
|
||||
|
||||
def test_prepare_image():
    """Check image preprocessing: uint8 → float32, values rescaled to [0, 1]."""
    image_uint8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
    # Pin both extremes so the end-point scaling checks below always run,
    # instead of being silently skipped when the random draw lacks a 0 or a 255.
    image_uint8[0, 0, 0] = 0
    image_uint8[0, 0, 1] = 255

    processed = prepare_image(image_uint8)

    # Check dtype conversion
    assert processed.dtype == torch.float32

    # Check normalization range
    assert processed.min() >= 0.0
    assert processed.max() <= 1.0

    # Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
    assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
    assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)

    # Check memory contiguity
    assert processed.is_contiguous()
|
||||
|
||||
|
||||
def test_resize_robot_observation_image():
    """Resizing maps an (H, W, C) robot frame onto the (C, H, W) shape requested."""
    policy_shape = (3, 224, 224)  # (C, H, W)
    robot_frame = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)  # (H=480, W=640, C=3)

    resized = resize_robot_observation_image(robot_frame, policy_shape)

    # Output takes the requested shape, which differs from the input's.
    assert resized.shape == policy_shape
    assert robot_frame.shape != resized.shape

    # Values stay inside the 8-bit range of the input.
    lowest, highest = resized.min(), resized.max()
    assert lowest >= 0
    assert highest <= 255
|
||||
|
||||
|
||||
def test_prepare_raw_observation():
    """`prepare_raw_observation` yields a batched state plus per-camera tensors in policy shape."""
    raw = _create_mock_robot_observation()
    features = _create_mock_lerobot_features()
    image_specs = _create_mock_policy_image_features()

    prepared = prepare_raw_observation(raw, features, image_specs)

    # State: present, a tensor, and carrying a leading batch dimension.
    assert "observation.state" in prepared
    state = prepared["observation.state"]
    assert isinstance(state, torch.Tensor)
    assert state.shape == (1, 4)

    camera_keys = ("observation.images.laptop", "observation.images.phone")

    # Images: present ...
    for key in camera_keys:
        assert key in prepared
    # ... resized to what the policy asks for ...
    for key in camera_keys:
        assert prepared[key].shape == image_specs[key].shape
    # ... and delivered as tensors.
    for key in camera_keys:
        assert isinstance(prepared[key], torch.Tensor)
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_basic():
    """Exercise `raw_observation_to_observation` on CPU: keys, shapes, device, dtype, range."""
    device = "cpu"
    observation = raw_observation_to_observation(
        _create_mock_robot_observation(),
        _create_mock_lerobot_features(),
        _create_mock_policy_image_features(),
        device,
    )

    # Expected (B, C, H, W) shape per camera.
    expected_image_shapes = {
        "observation.images.laptop": (1, 3, 224, 224),
        "observation.images.phone": (1, 3, 160, 160),
    }

    # All expected keys are present.
    for key in ("observation.state", *expected_image_shapes):
        assert key in observation

    # State: batched tensor on the requested device.
    state = observation["observation.state"]
    assert isinstance(state, torch.Tensor)
    assert state.device.type == device
    assert state.shape == (1, 4)

    images = {key: observation[key] for key in expected_image_shapes}

    for key, shape in expected_image_shapes.items():
        assert images[key].shape == shape

    for image in images.values():
        assert image.device.type == device

    # Images should come out as float32 with values in [0, 1].
    for image in images.values():
        assert image.dtype == torch.float32

    for image in images.values():
        assert image.min() >= 0.0 and image.max() <= 1.0
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_with_non_tensor_data():
    """A plain string entry (the task instruction) passes through the conversion unchanged."""
    instruction = "pick up the red cube"
    raw = _create_mock_robot_observation()
    raw["task"] = instruction

    observation = raw_observation_to_observation(
        raw,
        _create_mock_lerobot_features(),
        _create_mock_policy_image_features(),
        "cpu",
    )

    assert "task" in observation
    task_out = observation["task"]
    assert task_out == "pick up the red cube"
    assert isinstance(task_out, str)
|
||||
|
||||
|
||||
@torch.no_grad()
def test_raw_observation_to_observation_device_handling():
    """Every tensor in the converted observation sits on the requested device.

    The device is ``"mps"`` when that backend is available, ``"cpu"`` otherwise.
    """
    if torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    raw = _create_mock_robot_observation()
    features = _create_mock_lerobot_features()
    image_specs = _create_mock_policy_image_features()

    observation = raw_observation_to_observation(raw, features, image_specs, device)

    # Non-tensor entries are not subject to the device check.
    tensors = {key: value for key, value in observation.items() if isinstance(value, torch.Tensor)}
    for key, value in tensors.items():
        assert value.device.type == device, f"Tensor {key} not on {device}"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
    """Converting the same raw observation twice gives matching results."""
    raw = _create_mock_robot_observation()
    features = _create_mock_lerobot_features()
    image_specs = _create_mock_policy_image_features()

    # Two independent conversions of the identical input.
    first = raw_observation_to_observation(raw, features, image_specs, "cpu")
    second = raw_observation_to_observation(raw, features, image_specs, "cpu")

    assert set(first.keys()) == set(second.keys())

    for key, value in first.items():
        if not isinstance(value, torch.Tensor):
            # Plain Python values are compared with ==.
            assert value == second[key]
            continue
        torch.testing.assert_close(value, second[key])
|
||||
|
||||
|
||||
def test_image_processing_pipeline_preserves_content():
    """A bright centre square stays brighter than the dark border after conversion."""
    # 100x100 black frame with a white 50x50 square in the middle.
    frame = np.zeros((100, 100, 3), dtype=np.uint8)
    frame[25:75, 25:75, :] = 255

    raw = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": frame}
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": [4],
            "names": ["shoulder", "elbow", "wrist", "gripper"],
        },
        "observation.images.laptop": {
            "dtype": "image",
            "shape": [100, 100, 3],
            "names": ["height", "width", "channels"],
        },
    }
    # The policy asks for 50x50, i.e. half the source resolution.
    image_specs = {
        "observation.images.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 50, 50)),
    }

    observation = raw_observation_to_observation(raw, features, image_specs, "cpu")

    # Drop the batch dimension before indexing pixels.
    image = observation["observation.images.laptop"].squeeze(0)

    # Exact pixel values may change through resizing; only the ordering
    # "centre brighter than corner" is asserted.
    center_val = image[:, 25, 25].mean()
    corner_val = image[:, 5, 5].mean()

    assert center_val > corner_val, "Image processing should preserve recognizable patterns"
|
||||
215
tests/async_inference/test_policy_server.py
Normal file
215
tests/async_inference/test_policy_server.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `PolicyServer` core logic.
|
||||
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockPolicy:
    """Stand-in for a real policy that always answers with all-zero actions.

    The individual supported policies have their own tests under tests/policies.
    """

    # NOTE(review): the scraped source lost its indentation; `image_features` is
    # assumed to belong to `_Config` and the remaining methods to `MockPolicy`.
    class _Config:
        robot_type = "dummy_robot"

        @property
        def image_features(self) -> dict[str, PolicyFeature]:
            """No image features: the tests using this mock feed no images."""
            return {}

    def __init__(self):
        self.config = self._Config()

    def to(self, *args, **kwargs):
        # Accepts whatever `policy.to(...)` is called with and does nothing.
        return self

    def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return zeros of shape (batch, 20, 6); batch is read from the state entry."""
        n_batch = len(observation["observation.state"])
        return torch.zeros((n_batch, 20, 6))

    def model(self, batch: dict) -> torch.Tensor:
        # Zeros of shape (batch, 20, 6); batch is the length of batch["robot_type"].
        n_batch = len(batch["robot_type"])
        return torch.zeros((n_batch, 20, 6))
|
||||
|
||||
|
||||
@pytest.fixture
@require_package("grpc")
def policy_server():
    """Provide a new `PolicyServer` whose policy is replaced by `MockPolicy`."""
    # Deferred imports: these modules are only loaded when the fixture body runs.
    from lerobot.scripts.server.configs import PolicyServerConfig
    from lerobot.scripts.server.policy_server import PolicyServer

    server = PolicyServer(PolicyServerConfig(host="localhost", port=9999))

    # Swap in the stub policy and the settings that go with it.
    server.policy = MockPolicy()
    server.actions_per_chunk = 20
    server.device = "cpu"

    # Feature mapping needed by the observation-similarity checks.
    state_feature = {
        "dtype": "float32",
        "shape": [6],
        "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
    }
    server.lerobot_features = {"observation.state": state_feature}

    return server
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
    """Wrap up to six state entries in a TimedObservation keyed ``joint1`` .. ``joint6``.

    Args:
        state: 1-D tensor; entry ``i`` becomes the value of joint ``i + 1``.
            Joints for which ``state`` has no entry are filled with ``0.0``.
        timestep: timestep stored on the observation.
        must_go: flag stored on the observation.

    Returns:
        A TimedObservation stamped with the current wall-clock time.
    """
    # Deferred import, as elsewhere in this test module.
    from lerobot.scripts.server.helpers import TimedObservation

    joint_names = ("joint1", "joint2", "joint3", "joint4", "joint5", "joint6")
    n_values = len(state)
    joints = {
        name: (state[idx].item() if n_values > idx else 0.0)
        for idx, name in enumerate(joint_names)
    }

    return TimedObservation(
        observation=joints,
        timestamp=time.time(),
        timestep=timestep,
        must_go=must_go,
    )
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_time_action_chunk(policy_server):
    """`_time_action_chunk` numbers actions consecutively and spaces them by `environment_dt`."""
    chunk_start_ts = time.time()
    first_timestep = 10
    chunk = [torch.randn(6) for _ in range(3)]  # three action tensors

    timed_actions = policy_server._time_action_chunk(chunk_start_ts, chunk, first_timestep)

    assert len(timed_actions) == 3

    # Timesteps count up from the starting timestep.
    timesteps = [timed.get_timestep() for timed in timed_actions]
    assert timesteps == [10, 11, 12]

    # Timestamps advance by one environment_dt per action.
    dt = policy_server.config.environment_dt
    expected_timestamps = [chunk_start_ts, chunk_start_ts + dt, chunk_start_ts + 2 * dt]
    for timed, expected in zip(timed_actions, expected_timestamps, strict=True):
        assert abs(timed.get_timestamp() - expected) < 1e-6
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_must_go(policy_server):
    """An observation flagged `must_go=True` gets enqueued."""
    flagged = _make_obs(torch.zeros(6), must_go=True)

    accepted = policy_server._enqueue_observation(flagged)

    assert accepted is True
    queue = policy_server.observation_queue
    assert queue.qsize() == 1
    # The queue holds the very same object that was submitted.
    assert queue.get_nowait() is flagged
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_dissimilar(policy_server):
    """Without `must_go`, an observation far from the last processed one is enqueued."""
    # Baseline: pretend an all-zero observation was processed last.
    policy_server.last_processed_obs = _make_obs(torch.zeros(6))

    # Candidate with every joint at 5.0, far away from the baseline.
    far_obs = _make_obs(torch.ones(6) * 5)
    accepted = policy_server._enqueue_observation(far_obs)

    assert accepted is True
    assert policy_server.observation_queue.qsize() == 1
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_is_skipped(policy_server):
    """Without `must_go`, an observation close to the last processed one is dropped."""
    # Baseline: pretend an all-zero observation was processed last.
    policy_server.last_processed_obs = _make_obs(torch.zeros(6))

    # Candidate differing from the baseline by only 1e-4 per joint.
    near_obs = _make_obs(torch.zeros(6) + 1e-4)
    accepted = policy_server._enqueue_observation(near_obs)

    assert accepted is False
    assert policy_server.observation_queue.empty() is True
|
||||
|
||||
|
||||
def test_obs_sanity_checks(policy_server):
    """Cover the three outcomes of the private `_obs_sanity_checks` helper."""
    check = policy_server._obs_sanity_checks
    previous = _make_obs(torch.zeros(6), timestep=0)

    # Rejected: the observation's timestep is already in the predicted set.
    policy_server._predicted_timesteps.add(1)
    repeated_timestep = _make_obs(torch.ones(6), timestep=1)
    assert check(repeated_timestep, previous) is False

    # Rejected: new timestep, but the state barely differs from the previous one.
    policy_server._predicted_timesteps.clear()
    near_duplicate = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
    assert check(near_duplicate, previous) is False

    # Accepted: new timestep and a clearly different state.
    fresh = _make_obs(torch.ones(6) * 5, timestep=3)
    assert check(fresh, previous) is True
|
||||
|
||||
|
||||
def test_predict_action_chunk(monkeypatch, policy_server):
    """Run `_predict_action_chunk` with `_get_action_chunk` replaced by a zero-tensor stub."""
    # Deferred import, as elsewhere in this test module.
    from lerobot.scripts.server.policy_server import PolicyServer

    # Present the server as an "act" policy for this test.
    policy_server.policy_type = "act"
    actions_per_chunk = policy_server.actions_per_chunk
    stub_shape = (1, actions_per_chunk, 6)  # (batch, actions per chunk, action dim)

    def _fake_get_action_chunk(_self, _obs, _type="act"):
        return torch.zeros(*stub_shape)

    monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)

    first_timestep = 5
    obs = _make_obs(torch.zeros(6), timestep=first_timestep)
    timed_actions = policy_server._predict_action_chunk(obs)

    # One timed action per chunk entry, numbered from the observation's timestep.
    assert len(timed_actions) == actions_per_chunk
    timesteps = [timed.get_timestep() for timed in timed_actions]
    assert timesteps == list(range(5, 5 + actions_per_chunk))

    # Timestamps start at the observation's and advance by environment_dt.
    for offset, timed in enumerate(timed_actions):
        expected_ts = obs.get_timestamp() + offset * policy_server.config.environment_dt
        assert abs(timed.get_timestamp() - expected_ts) < 1e-6
|
||||
234
tests/async_inference/test_robot_client.py
Normal file
234
tests/async_inference/test_robot_client.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
|
||||
|
||||
We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that
|
||||
no real hardware is accessed. Only the queue-update mechanism is verified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if grpc is not available
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
def robot_client():
    """Yield a new `RobotClient` built around a `MockRobotConfig`; no threads are started.

    After the test, the client is stopped if its robot reports being connected.
    """
    # Deferred imports: these modules are only loaded when the fixture body runs.
    from lerobot.scripts.server.configs import RobotClientConfig
    from lerobot.scripts.server.robot_client import RobotClient
    from tests.mocks.mock_robot import MockRobotConfig

    robot_config = MockRobotConfig()

    # The gRPC channel is not actually used by these tests, so the address is a dummy.
    client_config = RobotClientConfig(
        robot=robot_config,
        server_address="localhost:9999",
        policy_type="test",
        pretrained_name_or_path="test",
        actions_per_chunk=20,
        verify_robot_cameras=False,
    )
    client = RobotClient(client_config)

    # Attributes that start() would normally initialize.
    client.chunks_received = 0
    client.available_actions_size = []

    yield client

    # Teardown.
    if client.robot.is_connected:
        client.stop()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_actions(start_ts: float, start_t: int, count: int):
    """Build `count` `TimedAction` objects with consecutive timesteps starting at `start_t`.

    Args:
        start_ts: Timestamp assigned to the first action; each following action
            is offset by a further `1 / fps` (with `fps = 30`).
        start_t: Timestep of the first action.
        count: Number of actions to generate.

    Returns:
        A list of `TimedAction`; each action tensor is a float32 tensor of shape
        `(6,)` filled with that action's timestep.
    """
    from lerobot.scripts.server.helpers import TimedAction

    fps = 30  # emulates most common frame-rate
    return [
        TimedAction(
            action=torch.full((6,), start_t + offset, dtype=torch.float32),
            timestep=start_t + offset,
            timestamp=start_ts + offset * (1 / fps),
        )
        for offset in range(count)
    ]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_action_queue_discards_stale(robot_client):
    """`_aggregate_action_queues` must drop incoming actions with `timestep` <= `latest_action`."""
    # Pretend every action up to and including timestep 4 was already executed.
    robot_client.latest_action = 4

    # The incoming chunk spans timesteps 3..7; 3 and 4 are stale, so 5, 6, 7 should be kept.
    chunk = _make_actions(start_ts=time.time(), start_t=3, count=5)  # 3,4,5,6,7
    robot_client._aggregate_action_queues(chunk)

    # Read the timesteps straight off the queue's underlying container.
    kept_timesteps = []
    for queued in robot_client.action_queue.queue:
        kept_timesteps.append(queued.get_timestep())

    assert kept_timesteps == [5, 6, 7]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "weight_old, weight_new",
    [
        (1.0, 0.0),
        (0.0, 1.0),
        (0.5, 0.5),
        (0.2, 0.8),
        (0.8, 0.2),
        (0.1, 0.9),
        (0.9, 0.1),
    ],
)
def test_aggregate_action_queues_combines_actions_in_overlap(
    robot_client, weight_old: float, weight_new: float
):
    """`_aggregate_action_queues` must merge actions on overlapping timesteps with the
    supplied `aggregate_fn`; here a weighted sum is exercised with several coefficients."""
    from lerobot.scripts.server.helpers import TimedAction

    robot_client.chunks_received = 0

    # Pretend timesteps up to 4 were executed and the queue already holds timesteps 5 and 6.
    robot_client.latest_action = 4
    # `_make_actions` fills each action tensor with its timestep; scale by 10 so the
    # queued (old) values are distinguishable from the incoming (new) ones.
    queued_before = [
        TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
        for a in _make_actions(start_ts=time.time(), start_t=5, count=2)
    ]
    for old_action in queued_before:
        robot_client.action_queue.put(old_action)

    # Incoming chunk spans timesteps 3..7: 3 and 4 are stale, 5 and 6 overlap, 7 is new.
    incoming = _make_actions(start_ts=time.time(), start_t=3, count=5)  # 3,4,5,6,7

    overlap_timesteps = [5, 6]  # stale-dropping itself is covered by test_update_action_queue_discards_stale
    nonoverlap_timesteps = [7]

    def weighted_sum(x1, x2):
        return weight_old * x1 + weight_new * x2

    robot_client._aggregate_action_queues(incoming, aggregate_fn=weighted_sum)

    def by_timestep(action):
        return action.get_timestep()

    # Split the resulting queue into merged (overlap) and appended (non-overlap) actions.
    final_queue = list(robot_client.action_queue.queue)
    merged = sorted((a for a in final_queue if a.get_timestep() in overlap_timesteps), key=by_timestep)
    appended = sorted((a for a in final_queue if a.get_timestep() in nonoverlap_timesteps), key=by_timestep)

    # incoming[-3] and incoming[-2] are the new actions for timesteps 5 and 6.
    for idx, (old_action, new_action) in enumerate(zip(queued_before, incoming[-3:-1])):
        assert torch.allclose(
            merged[idx].get_action(),
            weight_old * old_action.get_action() + weight_new * new_action.get_action(),
        )

    # Timestep 7 had no queued counterpart, so the incoming action is kept as-is.
    assert torch.allclose(appended[0].get_action(), incoming[-1].get_action())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "chunk_size, queue_len, expected",
    [
        (20, 12, False),  # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
        (20, 8, True),  # 8 / 20 = 0.4 <= g=0.5, ready to send
        (10, 5, True),  # 5 / 10 = 0.5 <= g=0.5, ready to send (boundary)
        (10, 6, False),  # 6 / 10 = 0.6 > g=0.5, not ready to send
    ],
)
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
    """Validate `_ready_to_send_observation` ratio logic for various sizes at g = 0.5."""
    robot_client.action_chunk_size = chunk_size
    # The expected values above are computed for g = 0.5. Pin the threshold here
    # so the test does not silently depend on the configuration default.
    robot_client._chunk_size_threshold = 0.5

    # Start from an empty queue, then fill it with `queue_len` dummy entries.
    robot_client.action_queue = Queue()
    for act in _make_actions(start_ts=time.time(), start_t=0, count=queue_len):
        robot_client.action_queue.put(act)

    assert robot_client._ready_to_send_observation() is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "g_threshold, expected",
    [
        # Ready when `queue_size / chunk_size <= g`; the ratio is fixed at 6 / 10 = 0.6.
        (0.0, False),  # 0.6 <= 0.0 is False
        (0.1, False),
        (0.2, False),
        (0.3, False),
        (0.4, False),
        (0.5, False),
        (0.6, True),  # 0.6 <= 0.6 is True
        (0.7, True),
        (0.8, True),
        (0.9, True),
        (1.0, True),
    ],
)
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
    """Validate `_ready_to_send_observation` with a fixed queue/chunk ratio and varying `g`."""
    # Fixed sizes for this test: ratio = 6 / 10 = 0.6
    chunk_size, queue_len = 10, 6
    robot_client.action_chunk_size = chunk_size

    # The parameter under test.
    robot_client._chunk_size_threshold = g_threshold

    # Replace the queue with a fresh one holding `queue_len` dummy actions.
    robot_client.action_queue = Queue()
    for dummy in _make_actions(start_ts=time.time(), start_t=0, count=queue_len):
        robot_client.action_queue.put(dummy)

    assert robot_client._ready_to_send_observation() is expected
|
||||
@@ -394,56 +394,37 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
# @pytest.mark.skip("TODO after fix multidataset")
|
||||
@pytest.mark.skip("TODO after fix multidataset")
|
||||
def test_multidataset_frames():
|
||||
"""Check that all dataset frames are incorporated and aligned correctly."""
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
# logic that wouldn't be caught with two repo IDs.
|
||||
repo_ids = [
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
]
|
||||
|
||||
# dummy padding dimensions (simulate training setup)
|
||||
MAX_ACTION_DIM = 14
|
||||
MAX_STATE_DIM = 30
|
||||
MAX_NUM_IMAGES = 3
|
||||
MAX_IMAGE_DIM = 224
|
||||
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(
|
||||
repo_ids,
|
||||
max_action_dim=MAX_ACTION_DIM,
|
||||
max_state_dim=MAX_STATE_DIM,
|
||||
max_num_images=MAX_NUM_IMAGES,
|
||||
max_image_dim=MAX_IMAGE_DIM,
|
||||
)
|
||||
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
# check they match.
|
||||
expected_dataset_indices = []
|
||||
for i, sub_dataset in enumerate(sub_datasets):
|
||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||
|
||||
for expected_dataset_index, sub_item, multi_item in zip(
|
||||
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
||||
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
||||
):
|
||||
dataset_index = multi_item.pop("dataset_index")
|
||||
dataset_index = dataset_item.pop("dataset_index")
|
||||
assert dataset_index == expected_dataset_index
|
||||
|
||||
# we ignore padding_mask and dataset_index keys in multi_item
|
||||
extra_keys = {k for k in multi_item if "padding_mask" in k}
|
||||
filtered_multi_keys = set(multi_item.keys()) - extra_keys
|
||||
assert set(sub_item.keys()) == filtered_multi_keys, "mismatch in keys"
|
||||
|
||||
for k in sub_item:
|
||||
if k not in multi_item:
|
||||
continue
|
||||
v1, v2 = sub_item[k], multi_item[k]
|
||||
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
|
||||
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
|
||||
else:
|
||||
assert v1 == v2, f"value mismatch on key: {k}"
|
||||
assert sub_dataset_item.keys() == dataset_item.keys()
|
||||
for k in sub_dataset_item:
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
|
||||
@@ -219,7 +219,7 @@ def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
||||
|
||||
comm, error = bus._write(addr, length, id_, value)
|
||||
|
||||
assert mock_motors.stubs[stub].called
|
||||
assert mock_motors.stubs[stub].wait_called()
|
||||
assert comm == scs.COMM_SUCCESS
|
||||
assert error == 0
|
||||
|
||||
@@ -371,9 +371,9 @@ def test_reset_calibration(mock_motors, dummy_motors):
|
||||
|
||||
bus.reset_calibration()
|
||||
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs)
|
||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
|
||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_mins_stubs)
|
||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_maxes_stubs)
|
||||
|
||||
|
||||
def test_set_half_turn_homings(mock_motors, dummy_motors):
|
||||
@@ -410,7 +410,7 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
|
||||
|
||||
bus.reset_calibration.assert_called_once()
|
||||
assert mock_motors.stubs[read_pos_stub].called
|
||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
|
||||
|
||||
|
||||
def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||
|
||||
14
urdf/assets/base_motor_holder_so101_v1.part
Normal file
14
urdf/assets/base_motor_holder_so101_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "84d8ae1881704ebae1ffb70a",
|
||||
"documentMicroversion": "0eea3500852bdb2f58b1cb79",
|
||||
"documentVersion": "a5c3b0dfaa52ddd6829011cd",
|
||||
"elementId": "22efbe4e0bef24fcd20f96e5",
|
||||
"fullConfiguration": "default",
|
||||
"id": "MCOhripg0ry51VlsC",
|
||||
"isStandardContent": false,
|
||||
"name": "Base_motor_holder_SO101 v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/base_motor_holder_so101_v1.stl
LFS
Normal file
BIN
urdf/assets/base_motor_holder_so101_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/base_so101_v2.part
Normal file
14
urdf/assets/base_so101_v2.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "bf61a6bc85b1d1a8bf9ea51b",
|
||||
"documentMicroversion": "20484d37162a32a8a41a37f2",
|
||||
"documentVersion": "25801b070e5b360715de8a30",
|
||||
"elementId": "312f32f0073fa6e8e36fba7a",
|
||||
"fullConfiguration": "default",
|
||||
"id": "MY69cJlqvSzIiODdH",
|
||||
"isStandardContent": false,
|
||||
"name": "Base_SO101 v2 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/base_so101_v2.stl
LFS
Normal file
BIN
urdf/assets/base_so101_v2.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/motor_holder_so101_base_v1.part
Normal file
14
urdf/assets/motor_holder_so101_base_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "652d5731024e57367badfda6",
|
||||
"documentMicroversion": "56a8b8013480c176fd87df8d",
|
||||
"documentVersion": "984ac31c92cac3664c8effb3",
|
||||
"elementId": "6fb7b7f9315511b548d670ff",
|
||||
"fullConfiguration": "default",
|
||||
"id": "Mf4ZebMr4BkShucFj",
|
||||
"isStandardContent": false,
|
||||
"name": "Motor_holder_SO101_Base v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/motor_holder_so101_base_v1.stl
LFS
Normal file
BIN
urdf/assets/motor_holder_so101_base_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/motor_holder_so101_wrist_v1.part
Normal file
14
urdf/assets/motor_holder_so101_wrist_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "4bd66da73cacb4d946d43e44",
|
||||
"documentMicroversion": "2bf56247e58b70e90806e318",
|
||||
"documentVersion": "df78bb7089f1de7d5588d238",
|
||||
"elementId": "d7dfe76e402c21bbd8124e43",
|
||||
"fullConfiguration": "default",
|
||||
"id": "MN9BZ1p69dQQtKTjq",
|
||||
"isStandardContent": false,
|
||||
"name": "Motor_holder_SO101_Wrist v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/motor_holder_so101_wrist_v1.stl
LFS
Normal file
BIN
urdf/assets/motor_holder_so101_wrist_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/moving_jaw_so101_v1.part
Normal file
14
urdf/assets/moving_jaw_so101_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "46218c02ef80d36172edbb35",
|
||||
"documentMicroversion": "68b7d387e2500c451586ae59",
|
||||
"documentVersion": "79c101d1a0207b77362b561a",
|
||||
"elementId": "d4b1411d5d7333298f6e2458",
|
||||
"fullConfiguration": "default",
|
||||
"id": "MrHPLr9hZkrXwcSA4",
|
||||
"isStandardContent": false,
|
||||
"name": "Moving_Jaw_SO101 v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/moving_jaw_so101_v1.stl
LFS
Normal file
BIN
urdf/assets/moving_jaw_so101_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/rotation_pitch_so101_v1.part
Normal file
14
urdf/assets/rotation_pitch_so101_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "14078aa6723c502d07d6902e",
|
||||
"documentMicroversion": "c0fca717407275159bcc6ed7",
|
||||
"documentVersion": "3d9a887ff68fa477d98162b8",
|
||||
"elementId": "43d24b3857ff686b275578bf",
|
||||
"fullConfiguration": "default",
|
||||
"id": "MrQ6Kmk9QDZlwbp95",
|
||||
"isStandardContent": false,
|
||||
"name": "Rotation_Pitch_SO101 v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/rotation_pitch_so101_v1.stl
LFS
Normal file
BIN
urdf/assets/rotation_pitch_so101_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/sts3215_03a_no_horn_v1.part
Normal file
14
urdf/assets/sts3215_03a_no_horn_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "56e5f3702dad85e17841d2e2",
|
||||
"documentMicroversion": "7958a6acbc8e0d0a0a611746",
|
||||
"documentVersion": "29a4c51b8bf277a22743a333",
|
||||
"elementId": "8c14fb13a6557ec89ff5d227",
|
||||
"fullConfiguration": "default",
|
||||
"id": "MOcaIFg8XgL+Ybg9z",
|
||||
"isStandardContent": false,
|
||||
"name": "STS3215_03a_no_horn v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/sts3215_03a_no_horn_v1.stl
LFS
Normal file
BIN
urdf/assets/sts3215_03a_no_horn_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/sts3215_03a_v1.part
Normal file
14
urdf/assets/sts3215_03a_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "d2941bdba816affebdc6d6f0",
|
||||
"documentMicroversion": "5904ef3cea04a0d0bc88b698",
|
||||
"documentVersion": "dd4f7470101215836a4ae8c9",
|
||||
"elementId": "e670b72d49b06f88fad5dbd8",
|
||||
"fullConfiguration": "default",
|
||||
"id": "M5vQNpe0onRFueych",
|
||||
"isStandardContent": false,
|
||||
"name": "STS3215_03a v1 <5>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/sts3215_03a_v1.stl
LFS
Normal file
BIN
urdf/assets/sts3215_03a_v1.stl
LFS
Normal file
Binary file not shown.
14
urdf/assets/under_arm_so101_v1.part
Normal file
14
urdf/assets/under_arm_so101_v1.part
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"configuration": "default",
|
||||
"documentId": "9f5d6db47eb112442b9f130f",
|
||||
"documentMicroversion": "e99cf45162e34789bd99512b",
|
||||
"documentVersion": "817ebf29c5663d412edc0753",
|
||||
"elementId": "2813aaffe3c8a342616d3527",
|
||||
"fullConfiguration": "default",
|
||||
"id": "M9yAEiX02J3c4HqXa",
|
||||
"isStandardContent": false,
|
||||
"name": "Under_arm_SO101 v1 <1>",
|
||||
"partId": "JFD",
|
||||
"suppressed": false,
|
||||
"type": "Part"
|
||||
}
|
||||
BIN
urdf/assets/under_arm_so101_v1.stl
LFS
Normal file
BIN
urdf/assets/under_arm_so101_v1.stl
LFS
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user