mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
148 Commits
peft-preco
...
feat/anysk
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
764404a27e | ||
|
|
8b9451b585 | ||
|
|
ab4903e752 | ||
|
|
538cea6dbc | ||
|
|
5cd3572713 | ||
|
|
3399513e5e | ||
|
|
32fc4015ee | ||
|
|
cc72c813bf | ||
|
|
606f31a86e | ||
|
|
4933c9dcc7 | ||
|
|
7e25385024 | ||
|
|
cc70bff74d | ||
|
|
9f50913b9c | ||
|
|
4eb7694d47 | ||
|
|
edb5559b5b | ||
|
|
552ec76195 | ||
|
|
e75340b473 | ||
|
|
2a4c223ec7 | ||
|
|
1ee4d84f07 | ||
|
|
6bd40ca219 | ||
|
|
b879cf3d04 | ||
|
|
bd9e5c1a64 | ||
|
|
9271a0c900 | ||
|
|
af2f044f5a | ||
|
|
0caba222ef | ||
|
|
6d73f5bfe6 | ||
|
|
ef8f40c21b | ||
|
|
0232879245 | ||
|
|
2726b4e865 | ||
|
|
e126d35249 | ||
|
|
d7ae8cd699 | ||
|
|
2f96d8bf76 | ||
|
|
e129c71b4f | ||
|
|
a02d70389d | ||
|
|
0d4922ce49 | ||
|
|
eaeff78924 | ||
|
|
e2f3982e2c | ||
|
|
a73ac2bdbb | ||
|
|
95de732e55 | ||
|
|
b2383236ca | ||
|
|
4b98cc25c8 | ||
|
|
90780c4de8 | ||
|
|
6f6e046c53 | ||
|
|
8cd64eaad1 | ||
|
|
e620395416 | ||
|
|
0fbcbcdb2e | ||
|
|
674f5dfd75 | ||
|
|
7d430c8067 | ||
|
|
5f114c1d74 | ||
|
|
ad01ef19f4 | ||
|
|
59e8f4572c | ||
|
|
97e91698fb | ||
|
|
af0294198a | ||
|
|
421fdcce96 | ||
|
|
bb63ad9715 | ||
|
|
3c90a79c57 | ||
|
|
8e29c530ed | ||
|
|
b573b7a052 | ||
|
|
926184110b | ||
|
|
bf8ede852d | ||
|
|
f73db4394b | ||
|
|
bff91f9927 | ||
|
|
6d726266fd | ||
|
|
2962330bb1 | ||
|
|
067993bb11 | ||
|
|
e4dd00c8f5 | ||
|
|
e714ff22e2 | ||
|
|
3bbd161cfd | ||
|
|
6d7be63f59 | ||
|
|
b9d0dfb9a2 | ||
|
|
dce483060f | ||
|
|
c32b9182d9 | ||
|
|
a4d4ef0e7f | ||
|
|
9a5c96b2b1 | ||
|
|
0a6ca58299 | ||
|
|
688195fc46 | ||
|
|
99eb0bbafc | ||
|
|
16de8b3f19 | ||
|
|
580008663b | ||
|
|
52c424c5eb | ||
|
|
836195e59c | ||
|
|
be09a59e05 | ||
|
|
373a169bd2 | ||
|
|
00536c6c5b | ||
|
|
cdd3a859ef | ||
|
|
5276fc0d6f | ||
|
|
6a2882f978 | ||
|
|
8874547353 | ||
|
|
2864caad80 | ||
|
|
d998660aa1 | ||
|
|
7e5f3b35e9 | ||
|
|
01fea7c407 | ||
|
|
13bfee1aa4 | ||
|
|
79688a09f2 | ||
|
|
b2ff219624 | ||
|
|
66929c5935 | ||
|
|
5286ef8439 | ||
|
|
fe068df711 | ||
|
|
da41646073 | ||
|
|
46e19ae579 | ||
|
|
77dc49b3a3 | ||
|
|
33910673ec | ||
|
|
19dce78457 | ||
|
|
112b2d173a | ||
|
|
b825880c40 | ||
|
|
76d6b71b0a | ||
|
|
5de38813d9 | ||
|
|
6797ce615e | ||
|
|
a17df523e0 | ||
|
|
1c61b43b15 | ||
|
|
15724826dd | ||
|
|
2cdd9f43f7 | ||
|
|
d0f57f58d1 | ||
|
|
b8ec1152d4 | ||
|
|
6b8d4c75a6 | ||
|
|
d791a431fe | ||
|
|
473f1bd0e0 | ||
|
|
91ff9c4975 | ||
|
|
1d86c9b7f2 | ||
|
|
ba3d2148a3 | ||
|
|
8b6fc0ae05 | ||
|
|
242b65d2df | ||
|
|
ccfd609ece | ||
|
|
fbe4c8b94f | ||
|
|
4f7cd8d369 | ||
|
|
f844c7a458 | ||
|
|
7e9d05a799 | ||
|
|
ecd8cd23d2 | ||
|
|
a9d81e7f67 | ||
|
|
e2957d7783 | ||
|
|
963a3482fa | ||
|
|
603d44434f | ||
|
|
6106a8136c | ||
|
|
e670ac5daf | ||
|
|
75ab388ecd | ||
|
|
17c115c71f | ||
|
|
fc296548cb | ||
|
|
9701b9c273 | ||
|
|
6d0d65a5fe | ||
|
|
60efd875fa | ||
|
|
12043b3b5c | ||
|
|
a06f4b9140 | ||
|
|
20c22a2799 | ||
|
|
2f238fce15 | ||
|
|
ff271e8b51 | ||
|
|
a142c365dd | ||
|
|
b2ef6ae720 | ||
|
|
a64f2fd322 |
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -22,20 +22,21 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S
|
||||
- Short, concrete bullets of the modifications (files/behaviour).
|
||||
- Short note if this introduces breaking changes and migration steps.
|
||||
|
||||
## How was this tested
|
||||
## How was this tested (or how to run locally)
|
||||
|
||||
- Tests added: list new tests or test files.
|
||||
- Manual checks / dataset runs performed.
|
||||
- Instructions for the reviewer
|
||||
|
||||
## How to run locally (reviewer)
|
||||
Example:
|
||||
|
||||
- Run the relevant tests:
|
||||
- Ran the relevant tests:
|
||||
|
||||
```bash
|
||||
pytest -q tests/ -k <keyword>
|
||||
```
|
||||
|
||||
- Run a quick example or CLI (if applicable):
|
||||
- Reproduce with a quick example or CLI (if applicable):
|
||||
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
|
||||
7
.github/workflows/documentation.yml
vendored
7
.github/workflows/documentation.yml
vendored
@@ -33,6 +33,9 @@ on:
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
@@ -43,7 +46,7 @@ jobs:
|
||||
build_main_docs:
|
||||
name: Build Main Docs
|
||||
if: >
|
||||
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
|
||||
(github.event_name == 'push' || github.event_name == 'workflow_dispatch' || github.event_name == 'release') &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -51,7 +54,7 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module
|
||||
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
2
.github/workflows/fast_tests.yml
vendored
2
.github/workflows/fast_tests.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
19
.github/workflows/full_tests.yml
vendored
19
.github/workflows/full_tests.yml
vendored
@@ -61,7 +61,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
@@ -127,7 +127,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -186,15 +186,18 @@ jobs:
|
||||
steps:
|
||||
- name: Get Docker Hub Token and Delete Image
|
||||
# zizmor: ignore[template-injection]
|
||||
env:
|
||||
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }}
|
||||
run: |
|
||||
IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
|
||||
|
||||
IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2-)
|
||||
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
TOKEN=$(curl -s -H "Content-Type: application/json" \
|
||||
-X POST \
|
||||
-d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
|
||||
-d "{\"username\": \"$DOCKERHUB_LEROBOT_USERNAME\", \"password\": \"$DOCKERHUB_LEROBOT_PASSWORD\"}" \
|
||||
https://hub.docker.com/v2/users/login/ | jq -r .token)
|
||||
|
||||
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
|
||||
@@ -205,7 +208,7 @@ jobs:
|
||||
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
|
||||
-H "Authorization: JWT ${TOKEN}" \
|
||||
-X DELETE \
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/$IMAGE_TAG)
|
||||
|
||||
if [ "$HTTP_RESPONSE" -eq 204 ]; then
|
||||
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
44
.github/workflows/issue_labeler.yml
vendored
44
.github/workflows/issue_labeler.yml
vendored
@@ -42,38 +42,26 @@ jobs:
|
||||
|
||||
// Keyword Heuristics
|
||||
|
||||
// Domain Specific
|
||||
if (matches(/\b(bug|error|issue|fault|crash|exception)\b/i)) labelsToAdd.add('bug');
|
||||
if (matches(/\b(feature|enhancement|improvement|support|implement|proposal)\b/i)) labelsToAdd.add('enhancement');
|
||||
if (matches(/\b(question|help|how to||clarify|explain|unclear)\b/i)) labelsToAdd.add('question');
|
||||
if (matches(/\b(maintenance|documentation|docs|readme|tutorial|guide|wiki)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(example|script|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples');
|
||||
if (matches(/\b(bug|error|crash|exception)\b/i)) labelsToAdd.add('bug');
|
||||
if (matches(/\b(new feature|enhancement|improvement|proposal|feature request)\b/i)) labelsToAdd.add('enhancement');
|
||||
if (matches(/\b(question|how to|clarify|explain|how do i|help me|question about)\b/i)) labelsToAdd.add('question');
|
||||
if (matches(/\b(documentation|docs?|readme|tutorial|wiki|typo|docstring)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(example|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples');
|
||||
if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset');
|
||||
if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation');
|
||||
if (matches(/\b(train|training|loss|optimizer|backward|gradient|wandb|sac)\b/i)) labelsToAdd.add('training');
|
||||
if (matches(/\b(rerun|plot|video|render|visualiz|gif)/i)) labelsToAdd.add('visualization');
|
||||
if (matches(/\b(camera|realsense|lidar|depth|sensor|imu|microphone|rgbd)\b/i)) labelsToAdd.add('sensors');
|
||||
if (matches(/\b(aloha|koch|so-100|so100|mobile|teleop|manipulator|robots?)\b/i)) labelsToAdd.add('robots');
|
||||
if (matches(/\b(train|training|optimizer|gradient|wandb|sac)\b/i)) labelsToAdd.add('training');
|
||||
if (matches(/\b(rerun|plot|render|rendering|visualizer)/i)) labelsToAdd.add('visualization');
|
||||
if (matches(/\b(cameras?|opencv|realsense|lidars?|sensors?|imus?|microphones?|rgbd|encoders?)\b/i)) labelsToAdd.add('sensors');
|
||||
if (matches(/\b(urdf|actuators?|calibration|end-effector|kinematics)\b/i)) labelsToAdd.add('robots');
|
||||
if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators');
|
||||
if (matches(/\b(policy|policies|p0licy)\b/i)) labelsToAdd.add('policies');
|
||||
if (matches(/\b(processors?|pipeline)\b/i)) labelsToAdd.add('processor');
|
||||
if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmark)\b/i)) labelsToAdd.add('evaluation');
|
||||
|
||||
// Infrastructure & Code Quality
|
||||
if (matches(/\b(policy|policies|model?)\b/i)) labelsToAdd.add('policies');
|
||||
if (matches(/\b(processor|pipeline|preprocessor|postprocessor)s?\b/i)) labelsToAdd.add('processor');
|
||||
if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmarks?)\b/i)) labelsToAdd.add('evaluation');
|
||||
if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests');
|
||||
if (matches(/\b(ci|github actions|workflow|gha|actions?|pipeline)\b/i)) {
|
||||
labelsToAdd.add('CI');
|
||||
labelsToAdd.add('github_actions');
|
||||
}
|
||||
if (matches(/\b(perf|latency|throughput|fps|speed|performance)\b/i)) labelsToAdd.add('performance');
|
||||
if (matches(/\b(dependency|requirements|pip|conda|install error|importerror|package not found)\b/i)) labelsToAdd.add('dependencies');
|
||||
if (matches(/\b(python|pyproject|requirements(\.txt)?|pip install|typing error)\b/i)) labelsToAdd.add('python');
|
||||
|
||||
// Documentation & Meta
|
||||
if (matches(/\b(doc|documentation|docs|readme|typo|how to)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(refactor|cleanup|restructure|rename|modernize code)\b/i)) labelsToAdd.add('refactor');
|
||||
if (matches(/\b(release|changelog|version bump|cut a release|tag v)\b/i)) labelsToAdd.add('release');
|
||||
if (matches(/\b(breaking change|major change)\b/i)) labelsToAdd.add('breaking change');
|
||||
if (matches(/\b(ci|github actions?|github workflows?|gha|docker|pypi)\b/i)) labelsToAdd.add('CI');
|
||||
if (matches(/\b(perf|latency|throughput|fps|speed|performance|slow|fast|slower|faster|memory usage)\b/i)) labelsToAdd.add('performance');
|
||||
if (matches(/\b(dependency|dependencies|pip|install error|importerror|package not found|pyproject)\b/i)) labelsToAdd.add('dependencies');
|
||||
if (matches(/\b(configuration|config|arguments?|input feature|dracuss)\b/i)) labelsToAdd.add('configuration');
|
||||
|
||||
// Apply Labels
|
||||
const labels = Array.from(labelsToAdd).filter(Boolean);
|
||||
|
||||
4
.github/workflows/nightly.yml
vendored
4
.github/workflows/nightly.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -87,7 +87,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
4
.github/workflows/quality.yml
vendored
4
.github/workflows/quality.yml
vendored
@@ -43,12 +43,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
7
.github/workflows/release.yml
vendored
7
.github/workflows/release.yml
vendored
@@ -38,12 +38,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -177,4 +177,3 @@ jobs:
|
||||
|
||||
# TODO(Steven): Publish draft/pre-release and to test pypi weekly
|
||||
# TODO(Steven): Separate build and publish job
|
||||
# TODO(Steven): Tag documentation with the same version as the package
|
||||
|
||||
2
.github/workflows/security.yml
vendored
2
.github/workflows/security.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses]
|
||||
uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
22
.github/workflows/unbound_deps_tests.yml
vendored
22
.github/workflows/unbound_deps_tests.yml
vendored
@@ -20,8 +20,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
# schedule:
|
||||
# - cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
@@ -101,7 +101,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -162,15 +162,19 @@ jobs:
|
||||
steps:
|
||||
- name: Get Docker Hub Token and Delete Image
|
||||
# zizmor: ignore[template-injection]
|
||||
env:
|
||||
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }}
|
||||
run: |
|
||||
IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
|
||||
IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2)
|
||||
|
||||
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
TOKEN=$(curl -s -H "Content-Type: application/json" \
|
||||
-X POST \
|
||||
-d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
|
||||
-d "{\"username\": \"$DOCKERHUB_LEROBOT_USERNAME\", \"password\": \"$DOCKERHUB_LEROBOT_PASSWORD\"}" \
|
||||
https://hub.docker.com/v2/users/login/ | jq -r .token)
|
||||
|
||||
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
|
||||
@@ -181,7 +185,7 @@ jobs:
|
||||
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
|
||||
-H "Authorization: JWT ${TOKEN}" \
|
||||
-X DELETE \
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/$IMAGE_TAG)
|
||||
|
||||
if [ "$HTTP_RESPONSE" -eq 204 ]; then
|
||||
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
352
CONTRIBUTING.md
352
CONTRIBUTING.md
@@ -1,323 +1,83 @@
|
||||
# How to contribute to 🤗 LeRobot?
|
||||
# How to contribute to 🤗 LeRobot
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
is thus not the only way to help the community. Answering questions, helping
|
||||
others, reaching out and improving the documentations are immensely valuable to
|
||||
the community.
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
|
||||
It also helps us if you spread the word: reference the library from blog posts
|
||||
on the awesome projects it made possible, shout out on Twitter when it has
|
||||
helped you, or simply ⭐️ the repo to say "thank you".
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our
|
||||
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
|
||||
## Ways to Contribute
|
||||
|
||||
## You can contribute in so many ways!
|
||||
You can contribute in many ways:
|
||||
|
||||
Some of the ways you can contribute to 🤗 LeRobot:
|
||||
- **Fixing issues:** Resolve bugs or improve existing code.
|
||||
- **New features:** Develop new features.
|
||||
- **Extend:** Implement new models/policies, robots, or simulation environments and upload datasets to the Hugging Face Hub.
|
||||
- **Documentation:** Improve examples, guides, and docstrings.
|
||||
- **Feedback:** Submit tickets related to bugs or desired new features.
|
||||
|
||||
- Fixing outstanding issues with the existing code.
|
||||
- Implementing new models, datasets or simulation environments.
|
||||
- Contributing to the examples or to the documentation.
|
||||
- Submitting issues related to bugs or desired new features.
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
|
||||
|
||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
|
||||
## Development Setup
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
To contribute code, you need to set up a development environment.
|
||||
|
||||
## Submitting a new issue or feature request
|
||||
### 1. Fork and Clone
|
||||
|
||||
Do your best to follow these guidelines when submitting an issue or a feature
|
||||
request. It will make it easier for us to come back to you quickly and with good
|
||||
feedback.
|
||||
|
||||
### Did you find a bug?
|
||||
|
||||
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
|
||||
the problems they encounter. So thank you for reporting an issue.
|
||||
|
||||
First, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on Github under Issues).
|
||||
|
||||
Did not find it? :( So we can act quickly on it, please follow these steps:
|
||||
|
||||
- Include your **OS type and version**, the versions of **Python** and **PyTorch**.
|
||||
- A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
- The full traceback if an exception is raised.
|
||||
- Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
### Do you want a new feature?
|
||||
|
||||
A good feature request addresses the following points:
|
||||
|
||||
1. Motivation first:
|
||||
|
||||
- Is it related to a problem/frustration with the library? If so, please explain
|
||||
why. Providing a code snippet that demonstrates the problem is best.
|
||||
- Is it related to something you would need for a project? We'd love to hear
|
||||
about it!
|
||||
- Is it something you worked on and think could benefit the community?
|
||||
Awesome! Tell us what problem it solved for you.
|
||||
|
||||
2. Write a _paragraph_ describing the feature.
|
||||
3. Provide a **code snippet** that demonstrates its future use.
|
||||
4. In case this is related to a paper, please attach a link.
|
||||
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you
|
||||
post it.
|
||||
|
||||
## Adding new policies, datasets or environments
|
||||
|
||||
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
|
||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
||||
[pusht](https://github.com/huggingface/gym-pusht))
|
||||
and follow the same api design.
|
||||
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
|
||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
||||
|
||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
||||
|
||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
||||
- Set the required `name` class attribute.
|
||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/lerobot) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
git clone git@github.com:<your Github handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
git checkout main
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
|
||||
```bash
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry lock
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
6. Follow our style.
|
||||
|
||||
`lerobot` relies on `ruff` to format its source code
|
||||
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
|
||||
automatically as Git commit hooks.
|
||||
|
||||
Install `pre-commit` hooks:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
You can run these hooks whenever you need on staged files with:
|
||||
|
||||
```bash
|
||||
pre-commit
|
||||
```
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
git add modified_file.py
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
7. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
8. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
Fork the repository on GitHub, then clone your fork:
|
||||
|
||||
```bash
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
git clone https://github.com/<your-handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
### 2. Environment Installation
|
||||
|
||||
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
|
||||
|
||||
## Running Tests & Quality Checks
|
||||
|
||||
### Code Style (Pre-commit)
|
||||
|
||||
Install `pre-commit` hooks to run checks automatically before you commit:
|
||||
|
||||
```bash
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||
To run checks manually on all files:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use `pytest`. First, ensure you have test artifacts by installing **git-lfs**:
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
Run the full suite (this may require extras installed):
|
||||
|
||||
```bash
|
||||
python -m pytest -sv ./tests
|
||||
pytest -sv ./tests
|
||||
```
|
||||
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
you're working on.
|
||||
Or run a specific test file during development:
|
||||
|
||||
```bash
|
||||
pytest -sv tests/test_specific_feature.py
|
||||
```
|
||||
|
||||
## Submitting Issues & Pull Requests
|
||||
|
||||
Use the templates for required fields and examples.
|
||||
|
||||
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
|
||||
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
|
||||
|
||||
One member of the LeRobot team will then review your contribution.
|
||||
|
||||
Thank you for contributing to LeRobot!
|
||||
|
||||
13
README.md
13
README.md
@@ -10,6 +10,7 @@
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://discord.gg/q8Dzzpym3f)
|
||||
|
||||
</div>
|
||||
|
||||
@@ -99,11 +100,11 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
@@ -127,7 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
## Resources
|
||||
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
|
||||
|
||||
48
SECURITY.md
Normal file
48
SECURITY.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# Security Policy
|
||||
|
||||
## Project Status & Philosophy
|
||||
|
||||
`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn’t been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues.
|
||||
|
||||
Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab.
|
||||
|
||||
The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
|
||||
|
||||
#### Hugging Face Security Team
|
||||
|
||||
Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps.
|
||||
|
||||
#### Open Source Disclosures
|
||||
|
||||
If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch).
|
||||
|
||||
| Version | Supported |
|
||||
| -------- | --------- |
|
||||
| Latest | ✅ |
|
||||
| < Latest | ❌ |
|
||||
|
||||
## Secure Usage Guidelines
|
||||
|
||||
`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe.
|
||||
|
||||
### Remote Artefacts (Weights & Policies)
|
||||
|
||||
Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format.
|
||||
|
||||
`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots.
|
||||
|
||||
To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files.
|
||||
|
||||
### Remote Code
|
||||
|
||||
Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code.
|
||||
|
||||
Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository.
|
||||
219
benchmarks/audio/run_microphone_benchmark.py
Normal file
219
benchmarks/audio/run_microphone_benchmark.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from soundfile import read
|
||||
|
||||
from lerobot.microphones.configs import MicrophoneConfig
|
||||
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
make_microphones_from_configs,
|
||||
)
|
||||
from lerobot.utils.robot_utils import (
|
||||
precise_sleep,
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
microphones_configs: dict[str, MicrophoneConfig],
|
||||
audio_chunks_number: int,
|
||||
audio_chunks_duration: float,
|
||||
repetitions: int,
|
||||
multiprocessing: bool = False,
|
||||
):
|
||||
recording_dir = Path("outputs/audio_benchmark")
|
||||
recording_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create microphones
|
||||
microphones = make_microphones_from_configs(microphones_configs)
|
||||
|
||||
# Connect microphones
|
||||
for microphone in microphones.values():
|
||||
microphone.connect()
|
||||
|
||||
all_audio_chunks = []
|
||||
for i in range(repetitions):
|
||||
print(f"Repetition {i + 1}/{repetitions}...")
|
||||
|
||||
# Create audio chunks
|
||||
audio_chunks = {}
|
||||
for microphone_key in microphones:
|
||||
audio_chunks.update({microphone_key: []})
|
||||
|
||||
# Start recording
|
||||
async_microphones_start_recording(
|
||||
microphones,
|
||||
output_files=[
|
||||
recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
|
||||
],
|
||||
multiprocessing=multiprocessing,
|
||||
)
|
||||
|
||||
# Record audio chunks
|
||||
for j in range(audio_chunks_number):
|
||||
precise_sleep(audio_chunks_duration)
|
||||
|
||||
for microphone_key, microphone in microphones.items():
|
||||
audio_chunk = microphone.read()
|
||||
print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
|
||||
audio_chunks[microphone_key].append(audio_chunk)
|
||||
|
||||
# Stop recording
|
||||
async_microphones_stop_recording(microphones)
|
||||
|
||||
for microphone_key in microphones:
|
||||
audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)
|
||||
|
||||
all_audio_chunks.append(audio_chunks)
|
||||
|
||||
# Disconnect microphones
|
||||
for microphone in microphones.values():
|
||||
microphone.disconnect()
|
||||
|
||||
# Compute statistics
|
||||
cmap = plt.get_cmap("tab10")
|
||||
_, ax = plt.subplots(nrows=repetitions, ncols=len(microphones))
|
||||
chunk_length = np.zeros((repetitions, len(microphones)))
|
||||
record_length = np.zeros((repetitions, len(microphones)))
|
||||
for i in range(repetitions):
|
||||
for j, (microphone_key, microphone) in enumerate(microphones.items()):
|
||||
# Get recorded audio chunks
|
||||
recorded_audio_chunks = all_audio_chunks[i][microphone_key]
|
||||
|
||||
# Load recorded file
|
||||
recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
|
||||
if recorded_data.ndim == 1:
|
||||
recorded_data = np.expand_dims(recorded_data, axis=1)
|
||||
|
||||
record_length[i, j] = recorded_data.shape[0]
|
||||
chunk_length[i, j] = recorded_audio_chunks.shape[0]
|
||||
|
||||
for k, (chunk_data, record_data) in enumerate(
|
||||
zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
|
||||
):
|
||||
# Plot audio chunks and recorded data
|
||||
ax[i, j].plot(
|
||||
np.arange(0, len(chunk_data)) / microphone.sample_rate,
|
||||
chunk_data,
|
||||
label=f"audio chunks - channel {k}",
|
||||
color=cmap(2 * k),
|
||||
)
|
||||
ax[i, j].plot(
|
||||
np.arange(0, len(record_data)) / microphone.sample_rate,
|
||||
record_data,
|
||||
label=f"recorded data - channel {k}",
|
||||
linestyle="dashed",
|
||||
color=cmap(2 * k + 1),
|
||||
)
|
||||
|
||||
# Plot absolute difference (errors should be located at the end of the recordings)
|
||||
if recorded_data.shape[0] - recorded_audio_chunks.shape[0] > 0:
|
||||
chunk_data = np.append(
|
||||
chunk_data, np.zeros(int(recorded_data.shape[0] - recorded_audio_chunks.shape[0]))
|
||||
)
|
||||
else:
|
||||
record_data = np.append(
|
||||
record_data, np.zeros(int(-recorded_data.shape[0] + recorded_audio_chunks.shape[0]))
|
||||
)
|
||||
ax[i, j].plot(
|
||||
np.arange(0, len(record_data)) / microphone.sample_rate,
|
||||
np.abs(chunk_data - record_data),
|
||||
label=f"differences - channel {k}",
|
||||
color="red",
|
||||
linestyle="dotted",
|
||||
)
|
||||
ax[i, j].set_title(f"{microphone_key} - repetition {i}")
|
||||
ax[i, j].legend()
|
||||
|
||||
plt.show()
|
||||
|
||||
# Print statistics
|
||||
differences = record_length - chunk_length
|
||||
for i, (microphone_key, microphone) in enumerate(microphones.items()):
|
||||
print(
|
||||
f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||
)
|
||||
print(
|
||||
f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||
)
|
||||
print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
|
||||
print(
|
||||
f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--microphones_indices",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[microphone["index"] for microphone in PortAudioMicrophone.find_microphones()],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microphones_sample_rate",
|
||||
type=float,
|
||||
nargs="+",
|
||||
default=[None] * len(PortAudioMicrophone.find_microphones()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microphones_channels",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[None] * len(PortAudioMicrophone.find_microphones()),
|
||||
)
|
||||
parser.add_argument("--audio_chunks_number", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--audio_chunks_duration",
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repetitions",
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multiprocessing",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["microphones_configs"] = {}
|
||||
for index, sample_rate, channels in zip(
|
||||
args["microphones_indices"],
|
||||
args["microphones_sample_rate"],
|
||||
args["microphones_channels"],
|
||||
strict=False,
|
||||
):
|
||||
microphone_config = PortAudioMicrophoneConfig(
|
||||
microphone_index=index,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
)
|
||||
args["microphones_configs"].update({f"microphone_{index}": microphone_config})
|
||||
args.pop("microphones_indices")
|
||||
args.pop("microphones_sample_rate")
|
||||
args.pop("microphones_channels")
|
||||
|
||||
main(**args)
|
||||
136
benchmarks/audio/run_tactile_benchmark.py
Normal file
136
benchmarks/audio/run_tactile_benchmark.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from lerobot.microphones.anyskin import AnyskinSensorConfig
|
||||
from lerobot.microphones.configs import MicrophoneConfig
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
make_microphones_from_configs,
|
||||
)
|
||||
from lerobot.utils.robot_utils import (
|
||||
precise_sleep,
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
sensors_configs: dict[str, MicrophoneConfig],
|
||||
multiprocessing: bool = False,
|
||||
):
|
||||
recording_dir = Path("outputs/tactile_benchmark")
|
||||
recording_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create microphones
|
||||
sensors = make_microphones_from_configs(sensors_configs)
|
||||
|
||||
# Connect microphones
|
||||
for sensor in sensors.values():
|
||||
sensor.connect()
|
||||
|
||||
# Create audio chunks
|
||||
data_chunks = {}
|
||||
for sensor_key in sensors:
|
||||
data_chunks.update({sensor_key: []})
|
||||
|
||||
# Start recording
|
||||
async_microphones_start_recording(
|
||||
sensors,
|
||||
output_files=[recording_dir / f"{sensor_key}_recording.wav" for sensor_key in sensors],
|
||||
multiprocessing=multiprocessing,
|
||||
)
|
||||
|
||||
# Record audio chunks
|
||||
precise_sleep(10.0)
|
||||
|
||||
for sensor_key, sensor in sensors.items():
|
||||
data_chunk = sensor.read()
|
||||
print(f"{sensor_key} - samples {data_chunk.shape[0]}")
|
||||
data_chunks[sensor_key].append(data_chunk)
|
||||
|
||||
# Stop recording
|
||||
async_microphones_stop_recording(sensors)
|
||||
|
||||
for sensor_key in sensors:
|
||||
data_chunks[sensor_key] = np.concatenate(data_chunks[sensor_key], axis=0)
|
||||
|
||||
# Disconnect microphones
|
||||
for sensor in sensors.values():
|
||||
sensor.disconnect()
|
||||
|
||||
for sensor_key in sensors:
|
||||
data, sample_rate = sf.read(recording_dir / f"{sensor_key}_recording.wav")
|
||||
print(f"{sensor_key} - samples {data.shape[0]}")
|
||||
print(f"{sensor_key} - sample rate {sample_rate}")
|
||||
print(f"{sensor_key} - data {data}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sensors_ports",
|
||||
type=str,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sensors_baud_rate",
|
||||
type=int,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sensors_sample_rate",
|
||||
type=int,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sensors_channels",
|
||||
type=int,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multiprocessing",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["sensors_configs"] = {}
|
||||
for port, baud_rate, sample_rate, channels in zip(
|
||||
args["sensors_ports"],
|
||||
args["sensors_baud_rate"],
|
||||
args["sensors_sample_rate"],
|
||||
args["sensors_channels"],
|
||||
strict=False,
|
||||
):
|
||||
channels = [1, 2, 3, 4, 5]
|
||||
sensor_config = AnyskinSensorConfig(
|
||||
sensor_port=port,
|
||||
baud_rate=baud_rate,
|
||||
sample_rate=sample_rate,
|
||||
channels=channels,
|
||||
)
|
||||
args["sensors_configs"].update({f"sensor_{port}": sensor_config})
|
||||
args.pop("sensors_ports")
|
||||
args.pop("sensors_baud_rate")
|
||||
args.pop("sensors_sample_rate")
|
||||
args.pop("sensors_channels")
|
||||
|
||||
main(**args)
|
||||
@@ -73,7 +73,7 @@ ENV HOME=/home/user_lerobot \
|
||||
RUN uv venv --python python${PYTHON_VERSION}
|
||||
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
@@ -59,7 +59,7 @@ ENV HOME=/home/user_lerobot \
|
||||
RUN uv venv
|
||||
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
title: Train RL in Simulation
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
@@ -35,12 +37,16 @@
|
||||
title: SmolVLA
|
||||
- local: pi0
|
||||
title: π₀ (Pi0)
|
||||
- local: pi0fast
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: walloss
|
||||
title: WALL-OSS
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: sarm
|
||||
@@ -57,6 +63,8 @@
|
||||
title: Environments from the Hub
|
||||
- local: envhub_leisaac
|
||||
title: Control & Train Robots in Sim (LeIsaac)
|
||||
- local: envhub_isaaclab_arena
|
||||
title: NVIDIA IsaacLab Arena Environments
|
||||
- local: libero
|
||||
title: Using Libero
|
||||
- local: metaworld
|
||||
|
||||
@@ -169,7 +169,7 @@ python -m lerobot.async_inference.robot_client \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import threading
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
|
||||
@@ -12,23 +12,42 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
|
||||
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
|
||||
The robot needs the [Frodobots SDK](https://github.com/frodobots-org/earth-rovers-sdk) running on your computer. Here's how:
|
||||
|
||||
1. Download and install the SDK:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Frodobots/earth-rovers-sdk.git
|
||||
git clone https://github.com/frodobots-org/earth-rovers-sdk.git
|
||||
cd earth-rovers-sdk
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Start the SDK:
|
||||
2. Save Credentials:
|
||||
|
||||
Write your .env variables with the SDK API key and bot name provided by the Frodobots team.
|
||||
|
||||
```bash
|
||||
SDK_API_TOKEN=your_sdk_api_token_here
|
||||
BOT_SLUG=your_bot_slug_here
|
||||
CHROME_EXECUTABLE_PATH=/path/to/chrome_or_chromium
|
||||
# Default value is MAP_ZOOM_LEVEL=18 https://wiki.openstreetmap.org/wiki/Zoom_levels
|
||||
MAP_ZOOM_LEVEL=18
|
||||
MISSION_SLUG=your_mission_slug_here
|
||||
# Image quality between 0.1 and 1.0 (default: 0.8)
|
||||
# Recommended: 0.8 for better performance
|
||||
IMAGE_QUALITY=0.8
|
||||
# Image format: jpeg, png or webp (default: png)
|
||||
# Recommended: jpeg for better performance and lower bandwidth usage
|
||||
IMAGE_FORMAT=jpeg
|
||||
```
|
||||
|
||||
3. Start the SDK:
|
||||
|
||||
```bash
|
||||
hypercorn main:app --reload
|
||||
```
|
||||
|
||||
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
|
||||
4. Open your web browser and go to `http://localhost:8000`, then click "Join"
|
||||
|
||||
The SDK gives you:
|
||||
|
||||
|
||||
@@ -2,14 +2,32 @@
|
||||
|
||||
The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
|
||||
|
||||
## Overview
|
||||
## What is EnvHub?
|
||||
|
||||
With EnvHub, you can:
|
||||
EnvHub lets you create custom robotics simulation environments with your own robot models and scenarios, and make them easily usable by anyone through the LeRobot framework.
|
||||
|
||||
- Load environments from the Hub instantly
|
||||
- Share your custom simulation tasks with the community
|
||||
- Version control your environments using Git
|
||||
- Distribute complex physics simulations without packaging hassles
|
||||
EnvHub packages are stored on the Hugging Face Hub, and can be seamlessly pulled and used in your AI robotics projects through LeRobot with a single line of code.
|
||||
|
||||
Thanks to EnvHub, you can:
|
||||
|
||||
1. **Create and publish environments** to the Hugging Face Hub as Git repositories, and distribute complex physics simulations without packaging hassles
|
||||
2. **Load environments** dynamically, without installing them as packages
|
||||
3. **Version and track** environment changes using Git semantics
|
||||
4. **Discover** new simulation tasks shared by the community
|
||||
|
||||
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, or create your own custom robot and environment without worrying about dependency conflicts or complex installation procedures.
|
||||
|
||||
When you create an EnvHub package, you can build anything you want inside it and use any simulation tool you like: this is your own space to play with. The only requirement is that the package contains an `env.py` file that defines the environment and allows LeRobot to load and use your EnvHub package.
|
||||
|
||||
This `env.py` file needs to expose a small API so LeRobot can load and run it. In particular, you must provide a `make_env(n_envs: int = 1, use_async_envs: bool = False)` or `make_env(n_envs: int = 1, use_async_envs: bool = False, cfg: EnvConfig)` function, which is the main entry point for LeRobot. It should return one of:
|
||||
|
||||
- A `gym.vector.VectorEnv` (most common)
|
||||
- A single `gym.Env` (will be automatically wrapped)
|
||||
- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
|
||||
|
||||
You can also pass an `EnvConfig` object to `make_env` to configure the environment (e.g. the number of environments, task, camera name, initial states, control mode, episode length, etc.).
|
||||
|
||||
Finally, your environment must implement the standard `gym.vector.VectorEnv` interface so it works with LeRobot, including methods like `reset` and `step`.
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -29,17 +47,6 @@ env = make_env("lerobot/cartpole-env", trust_remote_code=True)
|
||||
hash for reproducibility and security.
|
||||
</Tip>
|
||||
|
||||
## What is EnvHub?
|
||||
|
||||
EnvHub is a framework that allows researchers and developers to:
|
||||
|
||||
1. **Publish environments** to the Hugging Face Hub as Git repositories
|
||||
2. **Load environments** dynamically without installing them as packages
|
||||
3. **Version and track** environment changes using Git semantics
|
||||
4. **Discover** new simulation tasks shared by the community
|
||||
|
||||
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
|
||||
|
||||
## Repository Structure
|
||||
|
||||
To make your environment loadable from the Hub, your repository must contain at minimum:
|
||||
|
||||
510
docs/source/envhub_isaaclab_arena.mdx
Normal file
510
docs/source/envhub_isaaclab_arena.mdx
Normal file
@@ -0,0 +1,510 @@
|
||||
# NVIDIA IsaacLab Arena & LeRobot
|
||||
|
||||
LeRobot EnvHub now supports **GPU-accelerated simulation** with IsaacLab Arena for policy evaluation at scale.
|
||||
Train and evaluate imitation learning policies with high-fidelity simulation — all integrated into the LeRobot ecosystem.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/nvidia/isaaclab-arena-envs/resolve/main/assets/Gr1OpenMicrowaveEnvironment.png"
|
||||
alt="IsaacLab Arena - GR1 Microwave Environment"
|
||||
style={{ maxWidth: "100%", borderRadius: "8px", marginBottom: "1rem" }}
|
||||
/>
|
||||
|
||||
[IsaacLab Arena](https://github.com/isaac-sim/IsaacLab-Arena) integrates with NVIDIA IsaacLab to provide:
|
||||
|
||||
- 🤖 **Humanoid embodiments**: GR1, G1, Galileo with various configurations
|
||||
- 🎯 **Manipulation & loco-manipulation tasks**: Door opening, pick-and-place, button pressing, and more
|
||||
- ⚡ **GPU-accelerated rollouts**: Parallel environment execution on NVIDIA GPUs
|
||||
- 🖼️ **RTX Rendering**: Evaluate vision-based policies with realistic rendering, reflections and refractions
|
||||
- 📦 **LeRobot-compatible datasets**: Ready for training with GR00T N1x, PI0, SmolVLA, ACT, and Diffusion policies
|
||||
- 🔄 **EnvHub integration**: Load environments from HuggingFace EnvHub with one line
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Hardware requirements are shared with Isaac Sim, and are detailed in [Isaac Sim Requirements](https://docs.isaacsim.omniverse.nvidia.com/5.1.0/installation/requirements.html).
|
||||
|
||||
- NVIDIA GPU with CUDA support
|
||||
- NVIDIA driver compatible with IsaacSim 5.1.0
|
||||
- Linux (Ubuntu 22.04 / 24.04)
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
# 1. Create conda environment
|
||||
conda create -y -n lerobot-arena python=3.11
|
||||
conda activate lerobot-arena
|
||||
conda install -y -c conda-forge ffmpeg=7.1.1
|
||||
|
||||
# 2. Install Isaac Sim 5.1.0
|
||||
pip install "isaacsim[all,extscache]==5.1.0" --extra-index-url https://pypi.nvidia.com
|
||||
|
||||
# Accept NVIDIA EULA (required)
|
||||
export ACCEPT_EULA=Y
|
||||
export PRIVACY_CONSENT=Y
|
||||
|
||||
# 3. Install IsaacLab 2.3.0
|
||||
git clone https://github.com/isaac-sim/IsaacLab.git
|
||||
cd IsaacLab
|
||||
git checkout v2.3.0
|
||||
./isaaclab.sh -i
|
||||
cd ..
|
||||
|
||||
# 4. Install IsaacLab Arena
|
||||
git clone https://github.com/isaac-sim/IsaacLab-Arena.git
|
||||
cd IsaacLab-Arena
|
||||
git checkout release/0.1.1
|
||||
pip install -e .
|
||||
cd ..
|
||||
|
||||
|
||||
# 5. Install LeRobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e .
|
||||
cd ..
|
||||
|
||||
|
||||
# 6. Install additional dependencies
|
||||
pip install onnxruntime==1.23.2 lightwheel-sdk==1.0.1 vuer[all]==0.0.70 qpsolvers==4.8.1
|
||||
pip install numpy==1.26.0 # Isaac Sim 5.1 depends on numpy==1.26.0, this will be fixed in next release
|
||||
```
|
||||
|
||||
## Evaluating Policies
|
||||
|
||||
### Pre-trained Policies
|
||||
|
||||
The following trained policies are available:
|
||||
|
||||
| Policy | Architecture | Task | Link |
|
||||
| :-------------------------- | :----------- | :------------ | :----------------------------------------------------------------------- |
|
||||
| pi05-arena-gr1-microwave | PI0.5 | GR1 Microwave | [HuggingFace](https://huggingface.co/nvidia/pi05-arena-gr1-microwave) |
|
||||
| smolvla-arena-gr1-microwave | SmolVLA | GR1 Microwave | [HuggingFace](https://huggingface.co/nvidia/smolvla-arena-gr1-microwave) |
|
||||
|
||||
### Evaluate SmolVLA
|
||||
|
||||
```bash
|
||||
pip install -e ".[smolvla]"
|
||||
pip install numpy==1.26.0 # revert numpy to version 1.26
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=nvidia/smolvla-arena-gr1-microwave \
|
||||
--env.type=isaaclab_arena \
|
||||
--env.hub_path=nvidia/isaaclab-arena-envs \
|
||||
--rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \
|
||||
--policy.device=cuda \
|
||||
--env.environment=gr1_microwave \
|
||||
--env.embodiment=gr1_pink \
|
||||
--env.object=mustard_bottle \
|
||||
--env.headless=false \
|
||||
--env.enable_cameras=true \
|
||||
--env.video=true \
|
||||
--env.video_length=10 \
|
||||
--env.video_interval=15 \
|
||||
--env.state_keys=robot_joint_pos \
|
||||
--env.camera_keys=robot_pov_cam_rgb \
|
||||
--trust_remote_code=True \
|
||||
--eval.batch_size=1
|
||||
```
|
||||
|
||||
### Evaluate PI0.5
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
pip install numpy==1.26.0 # revert numpy to version 1.26
|
||||
```
|
||||
|
||||
<Tip>PI0.5 requires disabling torch compile for evaluation:</Tip>
|
||||
|
||||
```bash
|
||||
TORCH_COMPILE_DISABLE=1 TORCHINDUCTOR_DISABLE=1 lerobot-eval \
|
||||
--policy.path=nvidia/pi05-arena-gr1-microwave \
|
||||
--env.type=isaaclab_arena \
|
||||
--env.hub_path=nvidia/isaaclab-arena-envs \
|
||||
--rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \
|
||||
--policy.device=cuda \
|
||||
--env.environment=gr1_microwave \
|
||||
--env.embodiment=gr1_pink \
|
||||
--env.object=mustard_bottle \
|
||||
--env.headless=false \
|
||||
--env.enable_cameras=true \
|
||||
--env.video=true \
|
||||
--env.video_length=15 \
|
||||
--env.video_interval=15 \
|
||||
--env.state_keys=robot_joint_pos \
|
||||
--env.camera_keys=robot_pov_cam_rgb \
|
||||
--trust_remote_code=True \
|
||||
--eval.batch_size=1
|
||||
```
|
||||
|
||||
<Tip>
|
||||
To change the number of parallel environments, use the ```--eval.batch_size```
|
||||
flag.
|
||||
</Tip>
|
||||
|
||||
### What to Expect
|
||||
|
||||
During evaluation, you will see a progress bar showing the running success rate:
|
||||
|
||||
```
|
||||
Stepping through eval batches: 8%|██████▍ | 4/50 [00:45<08:06, 10.58s/it, running_success_rate=25.0%]
|
||||
```
|
||||
|
||||
### Video Recording
|
||||
|
||||
To enable video recording during evaluation, add the following flags to your command:
|
||||
|
||||
```bash
|
||||
--env.video=true \
|
||||
--env.video_length=15 \
|
||||
--env.video_interval=15
|
||||
```
|
||||
|
||||
For more details on video recording, see the [IsaacLab Recording Documentation](https://isaac-sim.github.io/IsaacLab/main/source/how-to/record_video.html).
|
||||
|
||||
<Tip>
|
||||
When running headless with `--env.headless=true`, you must also enable cameras explicitly for camera enabled environments:
|
||||
|
||||
```bash
|
||||
--env.headless=true --env.enable_cameras=true
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
### Output Directory
|
||||
|
||||
Evaluation videos are saved to the output directory with the following structure:
|
||||
|
||||
```
|
||||
outputs/eval/<date>/<timestamp>_<env>_<policy>/videos/<task>_<env_id>/eval_episode_<n>.mp4
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
outputs/eval/2026-01-02/14-38-01_isaaclab_arena_smolvla/videos/gr1_microwave_0/eval_episode_0.mp4
|
||||
```
|
||||
|
||||
## Training Policies
|
||||
|
||||
To learn more about training policies with LeRobot, please refer to the training documentation:
|
||||
|
||||
- [SmolVLA](./smolvla)
|
||||
- [Pi0.5](./pi05)
|
||||
- [GR00T N1.5](./groot)
|
||||
|
||||
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
|
||||
|
||||
| Dataset | Description | Frames |
|
||||
| :-------------------------------------------------------------------------------------------------------- | :------------------------- | :----- |
|
||||
| [Arena-GR1-Manipulation-Task](https://huggingface.co/datasets/nvidia/Arena-GR1-Manipulation-Task-v3) | GR1 microwave manipulation | ~4K |
|
||||
| [Arena-G1-Loco-Manipulation-Task](https://huggingface.co/datasets/nvidia/Arena-G1-Loco-Manipulation-Task) | G1 loco-manipulation | ~4K |
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
### Full Configuration Options
|
||||
|
||||
```python
|
||||
from lerobot.envs.configs import IsaaclabArenaEnv
|
||||
|
||||
config = IsaaclabArenaEnv(
|
||||
# Environment selection
|
||||
environment="gr1_microwave", # Task environment
|
||||
embodiment="gr1_pink", # Robot embodiment
|
||||
object="power_drill", # Object to manipulate
|
||||
|
||||
# Simulation settings
|
||||
episode_length=300, # Max steps per episode
|
||||
headless=True, # Run without GUI
|
||||
device="cuda:0", # GPU device
|
||||
seed=42, # Random seed
|
||||
|
||||
# Observation configuration
|
||||
state_keys="robot_joint_pos", # State observation keys (comma-separated)
|
||||
camera_keys="robot_pov_cam_rgb", # Camera observation keys (comma-separated)
|
||||
state_dim=54, # Expected state dimension
|
||||
action_dim=36, # Expected action dimension
|
||||
camera_height=512, # Camera image height
|
||||
camera_width=512, # Camera image width
|
||||
enable_cameras=True, # Enable camera observations
|
||||
|
||||
# Video recording
|
||||
video=False, # Enable video recording
|
||||
video_length=100, # Frames per video
|
||||
video_interval=200, # Steps between recordings
|
||||
|
||||
# Advanced
|
||||
mimic=False, # Enable mimic mode
|
||||
teleop_device=None, # Teleoperation device
|
||||
disable_fabric=False, # Disable fabric optimization
|
||||
enable_pinocchio=True, # Enable Pinocchio for IK
|
||||
)
|
||||
```
|
||||
|
||||
### Using Environment Hub directly for advanced usage
|
||||
|
||||
Create a file called `test_env_load_arena.py` or [download from the EnvHub](https://huggingface.co/nvidia/isaaclab-arena-envs/blob/main/tests/test_env_load_arena.py):
|
||||
|
||||
```python
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
import torch
|
||||
import tqdm
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: EvalPipelineConfig):
|
||||
"""Run random action rollout for IsaacLab Arena environment."""
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
env_dict = make_env(
|
||||
cfg.env,
|
||||
n_envs=cfg.env.num_envs,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
env = next(iter(env_dict.values()))[0]
|
||||
env.reset()
|
||||
for _ in tqdm.tqdm(range(cfg.env.episode_length)):
|
||||
with torch.inference_mode():
|
||||
actions = env.action_space.sample()
|
||||
obs, rewards, terminated, truncated, info = env.step(actions)
|
||||
if terminated.any() or truncated.any():
|
||||
obs, info = env.reset()
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
Run with:
|
||||
|
||||
```bash
|
||||
python test_env_load_arena.py \
|
||||
--env.environment=g1_locomanip_pnp \
|
||||
--env.embodiment=gr1_pink \
|
||||
--env.object=cracker_box \
|
||||
--env.num_envs=4 \
|
||||
--env.enable_cameras=true \
|
||||
--env.seed=1000 \
|
||||
--env.video=true \
|
||||
--env.video_length=10 \
|
||||
--env.video_interval=15 \
|
||||
--env.headless=false \
|
||||
--env.hub_path=nvidia/isaaclab-arena-envs \
|
||||
--env.type=isaaclab_arena
|
||||
```
|
||||
|
||||
## Creating New Environments
|
||||
|
||||
First create a new IsaacLab Arena environment by following the [IsaacLab Arena Documentation](https://isaac-sim.github.io/IsaacLab-Arena/release/0.1.1/index.html).
|
||||
|
||||
Clone our EnvHub repo:
|
||||
|
||||
```bash
|
||||
git clone https://huggingface.co/nvidia/isaaclab-arena-envs
|
||||
```
|
||||
|
||||
Modify the `example_envs.yaml` file based on your new environment.
|
||||
[Upload](./envhub#step-3-upload-to-the-hub) your modified repo to HuggingFace EnvHub.
|
||||
|
||||
<Tip>
|
||||
Your IsaacLab Arena environment code must be locally available during
|
||||
evaluation. Users can clone your environment repository separately, or you can
|
||||
bundle the environment code and assets directly in your EnvHub repo.
|
||||
</Tip>
|
||||
|
||||
Then, when evaluating, use your new environment:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--env.hub_path=<your-env-hub-path>/isaaclab-arena-envs \
|
||||
--env.environment=<your new environment> \
|
||||
...other flags...
|
||||
```
|
||||
|
||||
We look forward to your contributions!
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
Reduce `batch_size` or use a GPU with more VRAM:
|
||||
|
||||
```bash
|
||||
--eval.batch_size=1
|
||||
```
|
||||
|
||||
### EULA not accepted
|
||||
|
||||
Set environment variables before running:
|
||||
|
||||
```bash
|
||||
export ACCEPT_EULA=Y
|
||||
export PRIVACY_CONSENT=Y
|
||||
```
|
||||
|
||||
### Video recording not working
|
||||
|
||||
Enable cameras when running headless:
|
||||
|
||||
```bash
|
||||
--env.video=true --env.enable_cameras=true --env.headless=true
|
||||
```
|
||||
|
||||
### Policy output dimension mismatch
|
||||
|
||||
Ensure `action_dim` matches your policy:
|
||||
|
||||
```bash
|
||||
--env.action_dim=36
|
||||
```
|
||||
|
||||
### libGLU.so.1 Errors during Isaac Sim initialization
|
||||
|
||||
Ensure you have the following dependencies installed, this is likely to happen on headless machines.
|
||||
|
||||
```bash
|
||||
sudo apt update && sudo apt install -y libglu1-mesa libxt6
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [EnvHub Documentation](./envhub.mdx) - General EnvHub usage
|
||||
- [IsaacLab Arena GitHub](https://github.com/isaac-sim/IsaacLab-Arena)
|
||||
- [IsaacLab Documentation](https://isaac-sim.github.io/IsaacLab/)
|
||||
|
||||
## Lightwheel LW-BenchHub
|
||||
|
||||
[Lightwheel](https://www.lightwheel.ai) is bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem.
|
||||
LW-BenchHub collects and generates large-scale datasets via teleoperation that comply with the LeRobot specification, enabling out-of-the-box training and evaluation workflows.
|
||||
With the unified interface provided by EnvHub, developers can quickly build end-to-end experimental pipelines.
|
||||
|
||||
### Install
|
||||
|
||||
Assuming you followed the [Installation](#installation) steps, you can install LW-BenchHub with:
|
||||
|
||||
```bash
|
||||
conda install pinocchio -c conda-forge -y
|
||||
pip install numpy==1.26.0 # revert numpy to version 1.26
|
||||
|
||||
sudo apt-get install git-lfs && git lfs install
|
||||
|
||||
git clone https://github.com/LightwheelAI/lw_benchhub
|
||||
git lfs pull # Ensure LFS files (e.g., .usd assets) are downloaded
|
||||
|
||||
cd lw_benchhub
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more detailed instructions, please refer to the [LW-BenchHub Documentation](https://docs.lightwheel.net/lw_benchhub/usage/Installation).
|
||||
|
||||
### Lightwheel Tasks Dataset
|
||||
|
||||
LW-BenchHub datasets are available on HuggingFace Hub:
|
||||
|
||||
| Dataset | Description | Tasks | Frames |
|
||||
| :------------------------------------------------------------------------------------------------------------ | :---------------------- | :---- | :----- |
|
||||
| [Lightwheel-Tasks-X7S](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-X7S) | X7S LIBERO and RoboCasa | 117 | ~10.3M |
|
||||
| [Lightwheel-Tasks-Double-Piper](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-Double-Piper) | Double-Piper LIBERO | 130 | ~6.0M |
|
||||
| [Lightwheel-Tasks-G1-Controller](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-G1-Controller) | G1-Controller LIBERO | 62 | ~2.7M |
|
||||
| [Lightwheel-Tasks-G1-WBC](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-G1-WBC) | G1-WBC RoboCasa | 32 | ~1.5M |
|
||||
|
||||
For training policies, refer to the [Training Policies](#training-policies) section.
|
||||
|
||||
### Evaluating Policies
|
||||
|
||||
#### Pre-trained Policies
|
||||
|
||||
The following trained policies are available:
|
||||
|
||||
| Policy | Architecture | Task | Layout | Robot | Link |
|
||||
| :----------------------- | :----------- | :----------------------------- | :--------- | :-------------- | :------------------------------------------------------------------------------------ |
|
||||
| smolvla-double-piper-pnp | SmolVLA | L90K1PutTheBlackBowlOnThePlate | libero-1-1 | DoublePiper-Abs | [HuggingFace](https://huggingface.co/LightwheelAI/smolvla-double-piper-pnp/tree/main) |
|
||||
|
||||
#### Evaluate SmolVLA
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=LightwheelAI/smolvla-double-piper-pnp \
|
||||
--env.type=isaaclab_arena \
|
||||
--rename_map='{"observation.images.left_hand_camera_rgb": "observation.images.left_hand", "observation.images.right_hand_camera_rgb": "observation.images.right_hand", "observation.images.first_person_camera_rgb": "observation.images.first_person"}' \
|
||||
--env.hub_path=LightwheelAI/lw_benchhub_env \
|
||||
--env.kwargs='{"config_path": "configs/envhub/example.yml"}' \
|
||||
--trust_remote_code=true \
|
||||
--env.state_keys=joint_pos \
|
||||
--env.action_dim=12 \
|
||||
--env.camera_keys=left_hand_camera_rgb,right_hand_camera_rgb,first_person_camera_rgb \
|
||||
--policy.device=cuda \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=100
|
||||
```
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
Evaluation can be quickly launched by modifying the `robot`, `task`, and `layout` settings in the configuration file.
|
||||
|
||||
#### Full Configuration Options
|
||||
|
||||
```yml
|
||||
# =========================
|
||||
# Basic Settings
|
||||
# =========================
|
||||
disable_fabric: false
|
||||
device: cuda:0
|
||||
sensitivity: 1.0
|
||||
step_hz: 50
|
||||
enable_cameras: true
|
||||
execute_mode: eval
|
||||
episode_length_s: 20.0 # Episode length in seconds, increase if episodes timeout during eval
|
||||
|
||||
# =========================
|
||||
# Robot Settings
|
||||
# =========================
|
||||
robot: DoublePiper-Abs # Robot type, DoublePiper-Abs, X7S-Abs, G1-Controller or G1-Controller-DecoupledWBC
|
||||
robot_scale: 1.0
|
||||
|
||||
# =========================
|
||||
# Task & Scene Settings
|
||||
# =========================
|
||||
task: L90K1PutTheBlackBowlOnThePlate # Task name
|
||||
scene_backend: robocasa
|
||||
task_backend: robocasa
|
||||
debug_assets: null
|
||||
layout: libero-1-1 # Layout and style ID
|
||||
sources:
|
||||
- objaverse
|
||||
- lightwheel
|
||||
- aigen_objs
|
||||
object_projects: []
|
||||
usd_simplify: false
|
||||
seed: 42
|
||||
|
||||
# =========================
|
||||
# Object Placement Retry Settings
|
||||
# =========================
|
||||
max_scene_retry: 4
|
||||
max_object_placement_retry: 3
|
||||
|
||||
resample_objects_placement_on_reset: true
|
||||
resample_robot_placement_on_reset: true
|
||||
|
||||
# =========================
|
||||
# Replay Configuration Settings
|
||||
# =========================
|
||||
replay_cfgs:
|
||||
add_camera_to_observation: true
|
||||
render_resolution: [640, 480]
|
||||
```
|
||||
|
||||
### See Also
|
||||
|
||||
- [LW-BenchHub GitHub](https://github.com/LightwheelAI/LW-BenchHub)
|
||||
- [LW-BenchHub Documentation](https://docs.lightwheel.net/lw_benchhub/)
|
||||
@@ -137,7 +137,8 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
make_teleoperator_from_config,
|
||||
so101_leader,
|
||||
so_leader,
|
||||
bi_so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -196,7 +197,7 @@ def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
|
||||
obs, info = env.reset()
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
@@ -222,7 +223,7 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
|
||||
def main():
|
||||
teleoperate(TeleoperateConfig(
|
||||
teleop=so101_leader.SO101LeaderConfig(
|
||||
teleop=so_leader.SO101LeaderConfig(
|
||||
port="/dev/ttyACM0",
|
||||
id='leader',
|
||||
use_degrees=False,
|
||||
|
||||
@@ -12,6 +12,12 @@ Developers and researchers can post-train GR00T N1.5 with their own real or synt
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
|
||||
alt="An overview of GR00T"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
|
||||
|
||||
- Real captured data from robots.
|
||||
@@ -103,7 +109,7 @@ Once you have trained your model using your parameters you can run inference in
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
--robot.id=bimanual_follower \
|
||||
|
||||
@@ -58,8 +58,8 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
|
||||
from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
|
||||
from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader
|
||||
from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
@@ -195,9 +195,9 @@ lerobot-record \
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -408,8 +408,8 @@ lerobot-replay \
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -432,7 +432,7 @@ for idx in range(dataset.num_frames):
|
||||
}
|
||||
robot.send_action(action)
|
||||
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
```
|
||||
@@ -531,8 +531,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -18,7 +18,7 @@ If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus inter
|
||||
- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos
|
||||
|
||||
Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API.
|
||||
For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so101_follower/so101_follower.py)
|
||||
For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so_follower/so101_follower/so101_follower.py)
|
||||
|
||||
Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial):
|
||||
|
||||
|
||||
@@ -204,7 +204,7 @@ lerobot-calibrate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
|
||||
from lerobot.teleoperators.so_leader import SO100LeaderConfig, SO100Leader
|
||||
|
||||
config = SO100LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
|
||||
62
docs/source/peft_training.mdx
Normal file
62
docs/source/peft_training.mdx
Normal file
@@ -0,0 +1,62 @@
|
||||
# Parameter efficient fine-tuning with 🤗 PEFT
|
||||
|
||||
[🤗 PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting
|
||||
large pretrained models such as pre-trained policies (e.g., SmolVLA, π₀, ...) to new tasks without training all
|
||||
of the model's parameters while yielding comparable performance.
|
||||
|
||||
Install the `lerobot[peft]` optional package to enable PEFT support.
|
||||
|
||||
To read about all the possible methods of adaption, please refer to the [🤗 PEFT docs](https://huggingface.co/docs/peft/index).
|
||||
|
||||
## Training SmolVLA
|
||||
|
||||
In this section we'll show you how to train a pre-trained SmolVLA policy with PEFT on the libero dataset.
|
||||
For brevity we're only training on the `libero_spatial` subset. We will use `lerobot/smolvla_base` as the model
|
||||
to parameter efficiently fine-tune:
|
||||
|
||||
```
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--policy.repo_id=your_hub_name/my_libero_smolvla \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--policy.output_features=null \
|
||||
--policy.input_features=null \
|
||||
--policy.optimizer_lr=1e-3 \
|
||||
--policy.scheduler_decay_lr=1e-4 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--steps=100000 \
|
||||
--batch_size=32 \
|
||||
--peft.method_type=LORA \
|
||||
--peft.r=64
|
||||
```
|
||||
|
||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
||||
the closer you get to full fine-tuning
|
||||
|
||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||
if you want to see a specific PEFT method supported.
|
||||
|
||||
By default, PEFT will target the `q_proj` and `v_proj` layers of the LM expert in SmolVLA. It will also target the
|
||||
state and action projection matrices as they are most likely task-dependent. If you need to target different layers
|
||||
you can use `--peft.target_modules` to specify which layers to target. You can refer to the respective PEFT method's
|
||||
documentation to see what inputs are supported, (e.g., [LoRA's target_modules documentation](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules)).
|
||||
Usually a list of suffixes or a regex are supported. For example, to target the MLPs of the `lm_expert` instead of
|
||||
the `q` and `v` projections, use:
|
||||
|
||||
```
|
||||
--peft.target_modules='(model\.vlm_with_expert\.lm_expert\..*\.(down|gate|up)_proj|.*\.(state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out))'
|
||||
```
|
||||
|
||||
In case you need to fully fine-tune a layer instead of just adapting it, you can supply a list of layer suffixes
|
||||
to the `--peft.full_training_modules` parameter:
|
||||
|
||||
```
|
||||
--peft.full_training_modules=["state_proj"]
|
||||
```
|
||||
|
||||
The learning rate and the scheduled target learning rate can usually be scaled by a factor of 10 compared to the
|
||||
learning rate used for full fine-tuning (e.g., 1e-4 normal, so 1e-3 using LoRA).
|
||||
@@ -44,7 +44,7 @@ Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`.
|
||||
|
||||
Teleoperation example:
|
||||
|
||||
```36:43:examples/phone_so100_teleop.py
|
||||
```python
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
@@ -103,7 +103,7 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
|
||||
- Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame.
|
||||
|
||||
```examples/phone_to_so100/teleoperate.py
|
||||
```python
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
@@ -114,7 +114,7 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
|
||||
- The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs.
|
||||
|
||||
```src/lerobot/teleoperators/phone/phone_processor.py
|
||||
```python
|
||||
action["enabled"] = enabled
|
||||
action["target_x"] = -pos[1] if enabled else 0.0
|
||||
action["target_y"] = pos[0] if enabled else 0.0
|
||||
@@ -127,7 +127,7 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
|
||||
- The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed.
|
||||
|
||||
```examples/phone_to_so100/teleoperate.py
|
||||
```python
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
@@ -138,7 +138,7 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
|
||||
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
|
||||
|
||||
```examples/phone_to_so100/teleoperate.py
|
||||
```python
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
@@ -147,7 +147,7 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
|
||||
- The `GripperVelocityToJoint` step turns a velocity‑like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied.
|
||||
|
||||
```examples/phone_to_so100/teleoperate.py
|
||||
```python
|
||||
GripperVelocityToJoint(speed_factor=20.0)
|
||||
```
|
||||
|
||||
@@ -157,7 +157,7 @@ We use different IK initial guesses in the kinematic steps. As initial guess eit
|
||||
|
||||
- Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame.
|
||||
|
||||
```examples/phone_to_so100/record.py
|
||||
```python
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
@@ -167,7 +167,7 @@ We use different IK initial guesses in the kinematic steps. As initial guess eit
|
||||
|
||||
- Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback.
|
||||
|
||||
```examples/phone_to_so100/replay.py
|
||||
```python
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
|
||||
@@ -6,6 +6,12 @@
|
||||
|
||||
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png"
|
||||
alt="An overview of Pi0"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
### The Vision for Physical Intelligence
|
||||
|
||||
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
|
||||
@@ -64,6 +70,8 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.train_expert_only=false \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
@@ -79,6 +87,15 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
|
||||
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
### Training Parameters Explained
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ----------------------- | ------- | ------------------------------------------- |
|
||||
| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder |
|
||||
| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters |
|
||||
|
||||
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
|
||||
@@ -67,6 +67,8 @@ python src/lerobot/scripts/lerobot_train.py\
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.train_expert_only=false \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
@@ -82,6 +84,15 @@ python src/lerobot/scripts/lerobot_train.py\
|
||||
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
|
||||
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
### Training Parameters Explained
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ----------------------- | ------- | ------------------------------------------- |
|
||||
| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder |
|
||||
| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters |
|
||||
|
||||
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
|
||||
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
|
||||
246
docs/source/pi0fast.mdx
Normal file
246
docs/source/pi0fast.mdx
Normal file
@@ -0,0 +1,246 @@
|
||||
# π₀-FAST (Pi0-FAST)
|
||||
|
||||
π₀-FAST is a **Vision-Language-Action model for general robot control** that uses autoregressive next-token prediction to model continuous robot actions.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀-FAST combines the power of Vision-Language Models with a novel action tokenization approach called **FAST (Frequency-space Action Sequence Tokenization)**. This enables training autoregressive VLAs on highly dexterous tasks that are impossible with standard binning-based discretization, while training **up to 5x faster** than diffusion-based approaches like π₀.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png"
|
||||
alt="An overview of Pi0-FAST"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
### Why FAST?
|
||||
|
||||
Standard approaches for robot action tokenization use simple per-dimension, per-timestep binning schemes. While passable for simple behaviors, this rapidly breaks down for complex and dexterous skills that require precision and high-frequency control.
|
||||
|
||||
FAST solves this by compressing action sequences using signal processing techniques, resulting in a dense sequence of action tokens that can be predicted autoregressively—just like language tokens.
|
||||
|
||||
### How FAST Tokenization Works
|
||||
|
||||
The FAST tokenizer compresses action sequences through the following steps:
|
||||
|
||||
1. **Normalize**: Take a continuous action chunk of shape `(H, D)` where `H` is the horizon and `D` is the action dimension. Normalize using one of the supported normalization methods (Quantiles recommended to handle outliers).
|
||||
|
||||
2. **Discrete Cosine Transform (DCT)**: Apply DCT (via scipy) to each action dimension separately. DCT is a compression algorithm commonly used in image and audio codecs (JPEG, MP3).
|
||||
|
||||
3. **Quantization**: Round and remove insignificant coefficients for each action dimension, producing a sparse frequency matrix.
|
||||
|
||||
4. **Flatten**: Flatten the matrix into a 1D vector, with low-frequency components first.
|
||||
|
||||
5. **Byte Pair Encoding (BPE)**: Train a BPE tokenizer to compress the DCT coefficients into dense action tokens, typically achieving **10x compression** over prior tokenization approaches.
|
||||
|
||||
This approach can transform **any existing VLM** into a VLA by training it to predict these FAST tokens.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install π₀-FAST dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training a Custom FAST Tokenizer
|
||||
|
||||
You have two options for the FAST tokenizer:
|
||||
|
||||
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
|
||||
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
|
||||
|
||||
### Training Your Own Tokenizer
|
||||
|
||||
```bash
|
||||
lerobot-train-tokenizer \
|
||||
--repo_id "user/my-lerobot-dataset" \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:6" \
|
||||
--vocab_size 1024 \
|
||||
--scale 10.0 \
|
||||
--normalization_mode QUANTILES \
|
||||
--output_dir "./my_fast_tokenizer" \
|
||||
--push_to_hub \
|
||||
--hub_repo_id "username/my-action-tokenizer"
|
||||
```
|
||||
|
||||
### Key Tokenizer Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| ---------------------- | --------------------------------------------------------------------------------- | ------------ |
|
||||
| `--repo_id` | LeRobot dataset repository ID | Required |
|
||||
| `--action_horizon` | Number of future actions in each chunk | `10` |
|
||||
| `--encoded_dims` | Comma-separated dimension ranges to encode (e.g., `"0:6,7:23"`) | `"0:6,7:23"` |
|
||||
| `--vocab_size` | BPE vocabulary size | `1024` |
|
||||
| `--scale` | DCT scaling factor for quantization | `10.0` |
|
||||
| `--normalization_mode` | Normalization mode (`MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`, `IDENTITY`) | `QUANTILES` |
|
||||
| `--sample_fraction` | Fraction of chunks to sample per episode | `0.1` |
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀-FAST in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi0_fast
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training π₀-FAST, you can use the LeRobot training script:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0_fast \
|
||||
--output_dir=./outputs/pi0fast_training \
|
||||
--job_name=pi0fast_training \
|
||||
--policy.pretrained_path=lerobot/pi0_fast_base \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.max_action_tokens=256 \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
|
||||
## Inference
|
||||
|
||||
### KV-Caching for Fast Inference
|
||||
|
||||
π₀-FAST supports **KV-caching**, a widely used optimization in LLM inference. This caches the key-value pairs from the attention mechanism, avoiding redundant computation during autoregressive decoding.
|
||||
|
||||
```python
|
||||
# KV-caching is enabled by default
|
||||
policy.use_kv_cache=true
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
|
||||
```python
|
||||
from lerobot.policies.pi0_fast import PI0FastPolicy, PI0FastConfig
|
||||
|
||||
# Load the policy
|
||||
policy = PI0FastPolicy.from_pretrained("your-model-path")
|
||||
|
||||
# During inference
|
||||
actions = policy.predict_action_chunk(batch)
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
π₀-FAST uses a PaliGemma-based architecture:
|
||||
|
||||
- **Vision Encoder**: SigLIP vision tower for image understanding
|
||||
- **Language Model**: Gemma 2B for processing language instructions and predicting action tokens
|
||||
|
||||
The model takes images, text instructions, and robot state as input, and outputs discrete FAST tokens that are decoded back to continuous actions.
|
||||
|
||||
## Configuration Options
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| -------------------- | ----------------------------------------------- | ---------- |
|
||||
| `paligemma_variant` | VLM backbone variant (`gemma_300m`, `gemma_2b`) | `gemma_2b` |
|
||||
| `max_state_dim` | Maximum state vector dimension (padded) | `32` |
|
||||
| `max_action_dim` | Maximum action vector dimension (padded) | `32` |
|
||||
| `temperature` | Sampling temperature (0.0 for greedy) | `0.0` |
|
||||
| `max_decoding_steps` | Maximum decoding steps | `256` |
|
||||
| `use_kv_cache` | Enable KV caching for faster inference | `true` |
|
||||
|
||||
## Comparison with π₀
|
||||
|
||||
| Feature | π₀ | π₀-FAST |
|
||||
| --------------------- | ------------------------- | ---------------------------- |
|
||||
| Action Representation | Flow Matching (Diffusion) | Autoregressive Tokens (FAST) |
|
||||
| Training Speed | 1x | **5x faster** |
|
||||
| Dexterity | High | High |
|
||||
| Inference Method | Iterative Denoising | Autoregressive Decoding |
|
||||
| KV-Caching | N/A | Supported |
|
||||
|
||||
## Reproducing π₀Fast results
|
||||
|
||||
We reproduce the results of π₀Fast on the LIBERO benchmark using the LeRobot implementation. We take the LeRobot PiFast base model [lerobot/pi0fast-base](https://huggingface.co/lerobot/pi0fast-base) and finetune for an additional 40kk steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
|
||||
|
||||
The finetuned model can be found here:
|
||||
|
||||
- **π₀Fast LIBERO**: [lerobot/pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)
|
||||
|
||||
With the following training command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=lerobot/libero \
|
||||
--output_dir=outputs/libero_pi0fast \
|
||||
--job_name=libero_pi0fast \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=100000 \
|
||||
--save_freq=20000 \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
--policy.scheduler_warmup_steps=4000 \
|
||||
--policy.scheduler_decay_steps=100000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.max_action_tokens=256 \
|
||||
--policy.empty_cameras=1 \
|
||||
```
|
||||
|
||||
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
|
||||
|
||||
```bash
|
||||
tasks="libero_object,libero_spatial,libero_goal,libero_10"
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0fast-libero \
|
||||
--policy.max_action_tokens=256 \
|
||||
--env.type=libero \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--env.task=${tasks} \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--rename_map='{"observation.images.image":"observation.images.base_0_rgb","observation.images.image2":"observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
|
||||
|
||||
### Results
|
||||
|
||||
We obtain the following results on the LIBERO benchmark:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| ----------- | -------------- | ------------- | ----------- | --------- | -------- |
|
||||
| **π₀-fast** | 70.0 | 100.0 | 100.0 | 60.0 | **82.5** |
|
||||
|
||||
The full evaluation output folder, including videos, is available [here](https://drive.google.com/drive/folders/1HXpwPTRm4hx6g1sF2P7OOqGG0TwPU7LQ?usp=sharing)
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
|
||||
## References
|
||||
|
||||
- [FAST: Efficient Robot Action Tokenization](https://www.physicalintelligence.company/research/fast) - Physical Intelligence Blog
|
||||
- [OpenPI Repository](https://github.com/Physical-Intelligence/openpi) - Original implementation
|
||||
- [FAST Tokenizer on Hugging Face](https://huggingface.co/physical-intelligence/fast) - Pre-trained tokenizer
|
||||
45
docs/source/policy_walloss_README.md
Normal file
45
docs/source/policy_walloss_README.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# WALL-OSS
|
||||
|
||||
This repository contains the Hugging Face port of [**WALL-OSS**](https://x2robot.com/en/research/68bc2cde8497d7f238dde690), a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | Description |
|
||||
| ------------------ | ----------------------------------------------------- |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Paper: https://arxiv.org/pdf/2509.11766
|
||||
|
||||
Official Repository: https://github.com/X-Square-Robot/wall-x
|
||||
|
||||
Hugging Face: https://huggingface.co/x-square-robot
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zhai2025igniting,
|
||||
title = {Igniting VLMs Toward the Embodied Space},
|
||||
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
|
||||
journal = {arXiv preprint arXiv:2509.11766},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x).
|
||||
@@ -30,7 +30,7 @@ Each of these pipelines handle different conversions between different action an
|
||||
|
||||
Below is an example of the three pipelines that we use in the phone to SO-100 follower examples:
|
||||
|
||||
```69:90:examples/phone_so100_record.py
|
||||
```python
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
@@ -84,7 +84,7 @@ Dataset features are determined by the keys saved in the dataset. Each step can
|
||||
|
||||
Below is and example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples:
|
||||
|
||||
```src/lerobot/robots/so100_follower/robot_kinematic_processor.py
|
||||
```python
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
@@ -103,7 +103,7 @@ Here we declare what PolicyFeatures we modify in this step, so we know what feat
|
||||
|
||||
Below is an example of how we aggregate and merge features in the phone to SO-100 record example:
|
||||
|
||||
```121:145:examples/phone_so100_record.py
|
||||
```python
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
|
||||
@@ -38,6 +38,7 @@ docker run --rm -it \
|
||||
start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
|
||||
>
|
||||
> ```
|
||||
@@ -141,7 +142,7 @@ If you choose this option but still want to use the VR teleoperation application
|
||||
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.id=r2-0000 \
|
||||
@@ -150,6 +151,7 @@ python -m lerobot.record \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.with_mobile_base=false \
|
||||
--robot.with_torso_camera=true \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
@@ -165,7 +167,7 @@ python -m lerobot.record \
|
||||
**Extended setup overview (all options included):**
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=true \
|
||||
@@ -177,6 +179,8 @@ python -m lerobot.record \
|
||||
--robot.with_left_teleop_camera=true \
|
||||
--robot.with_right_teleop_camera=true \
|
||||
--robot.with_torso_camera=false \
|
||||
--robot.camera_width=640 \
|
||||
--robot.camera_height=480 \
|
||||
--robot.disable_torque_on_disconnect=false \
|
||||
--robot.max_relative_target=5.0 \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
@@ -212,9 +216,10 @@ Must be set to true if a compliant Reachy 2 is used to control another one.
|
||||
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
|
||||
To avoid this, you can exclude specific parts from recording and replay using:
|
||||
|
||||
````
|
||||
```bash
|
||||
--robot.with_<part>=false
|
||||
```,
|
||||
```
|
||||
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
It determine whether the corresponding part is recorded in the observations. True if not set.
|
||||
|
||||
@@ -222,49 +227,60 @@ By default, **all parts are recorded**.
|
||||
|
||||
The same per-part mechanism is available in `reachy2_teleoperator` as well.
|
||||
|
||||
````
|
||||
|
||||
```bash
|
||||
--teleop.with\_<part>
|
||||
|
||||
```
|
||||
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
Determine whether the corresponding part is recorded in the actions. True if not set.
|
||||
|
||||
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
|
||||
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
|
||||
> For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
|
||||
|
||||
##### Use the relevant cameras
|
||||
|
||||
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
|
||||
You can do the same for **cameras**. Enable or disable each camera with default parameters using:
|
||||
|
||||
```bash
|
||||
--robot.with_left_teleop_camera=<true|false> \
|
||||
--robot.with_right_teleop_camera=<true|false> \
|
||||
--robot.with_torso_camera=<true|false>
|
||||
```
|
||||
|
||||
--robot.with_left_teleop_camera=<true|false>
|
||||
--robot.with_right_teleop_camera=<true|false>
|
||||
--robot.with_torso_camera=<true|false>
|
||||
By default, no camera is recorded, all camera arguments are set to `false`.
|
||||
If you want to, you can use custom `width` and `height` parameters for Reachy 2's cameras using the `--robot.camera_width` & `--robot.camera_height` argument:
|
||||
|
||||
````
|
||||
```bash
|
||||
--robot.camera_width=1920 \
|
||||
--robot.camera_height=1080
|
||||
```
|
||||
|
||||
This will change the resolution of all 3 default robot cameras (enabled by the above bool arguments).
|
||||
|
||||
If you want, you can add additional cameras other than the ones in the robot as usual with:
|
||||
|
||||
```bash
|
||||
--robot.cameras="{ extra: {type: opencv, index_or_path: 42, width: 640, height: 480, fps: 30}}" \
|
||||
```
|
||||
|
||||
## Step 2: Replay
|
||||
|
||||
Make sure the robot is configured with the same parts as the dataset:
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=false \
|
||||
--robot.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.episode=0
|
||||
--display_data=true
|
||||
````
|
||||
```
|
||||
|
||||
## Step 3: Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/reachy2_test \
|
||||
@@ -277,10 +293,9 @@ python -m lerobot.scripts.train \
|
||||
## Step 4: Evaluate
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-eval \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=pollen_robotics/eval_record_test \
|
||||
--dataset.single_task="Evaluate reachy2 policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
|
||||
@@ -4,6 +4,12 @@ SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework fo
|
||||
|
||||
**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-sarm.png"
|
||||
alt="An overview of SARM"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
## Why Reward Models?
|
||||
|
||||
Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
|
||||
|
||||
@@ -103,7 +103,7 @@ lerobot-setup-motors \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
|
||||
config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
@@ -177,7 +177,7 @@ lerobot-setup-motors \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
|
||||
config = SO100LeaderConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
@@ -579,7 +579,7 @@ lerobot-calibrate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig, SO100Follower
|
||||
|
||||
config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
@@ -617,7 +617,7 @@ lerobot-calibrate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
|
||||
from lerobot.teleoperators.so_leader import SO100LeaderConfig, SO100Leader
|
||||
|
||||
config = SO100LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
|
||||
@@ -125,7 +125,7 @@ lerobot-setup-motors \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
@@ -201,7 +201,7 @@ lerobot-setup-motors \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
|
||||
config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
@@ -364,7 +364,7 @@ lerobot-calibrate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
|
||||
from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower
|
||||
|
||||
config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
@@ -413,7 +413,7 @@ lerobot-calibrate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
|
||||
from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader
|
||||
|
||||
config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
# Unitree G1 Robot Setup and Control
|
||||
# Unitree G1
|
||||
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
|
||||
## About the Unitree G1
|
||||
## About
|
||||
|
||||
We offer support for both 29 and 23 DOF G1. We introduce:
|
||||
We support both 29 and 23 DOF G1 EDU version. We introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level communication with the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||
- **GR00T locomotion policy** for bipedal walking and balance
|
||||
- **MuJoCo simulation mode** for testing policies without the physical robot
|
||||
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
|
||||
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
|
||||
- **Simulation mode** for testing policies without the physical robot in mujoco
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Connect to Robot over Ethernet
|
||||
## Connection guide
|
||||
|
||||
### Step 1: Configure Your Computer's Ethernet Interface
|
||||
### Step 1: Configure Ethernet Interface
|
||||
|
||||
Set a static IP on the same subnet as the robot:
|
||||
|
||||
@@ -26,7 +26,7 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
|
||||
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
|
||||
@@ -35,25 +35,24 @@ ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
You should now be connected to the robot's onboard computer.
|
||||
You should now be connected to the G1's Orin.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Once connected via Ethernet, follow these steps to enable WiFi:
|
||||
Wlan0 is disabled by default on the G1. To enable it:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
|
||||
```bash
|
||||
# Unblock WiFi radio
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up WiFi interface
|
||||
# Bring up wlan0
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control
|
||||
# Enable NetworkManager control of wlan0
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
@@ -73,7 +72,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the robot:**
|
||||
**On the G1:**
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
@@ -111,7 +110,7 @@ ssh unitree@<YOUR_ROBOT_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
|
||||
|
||||
---
|
||||
|
||||
@@ -147,9 +146,9 @@ python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Running GR00T Locomotion
|
||||
## Part 4: Controlling the robot
|
||||
|
||||
With the robot server running, you can now control the robot from your laptop.
|
||||
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
|
||||
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
@@ -172,34 +171,30 @@ Edit the config file to match your robot's WiFi IP:
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
# Run Holosoma locomotion controller
|
||||
python examples/unitree_g1/holosoma_locomotion.py
|
||||
|
||||
```
|
||||
|
||||
### Step 4: Control with Remote
|
||||
|
||||
- **Left stick**: Forward/backward and left/right movement
|
||||
- **Right stick**: Rotation
|
||||
- **R1 button**: Raise waist height
|
||||
- **R2 button**: Lower waist height
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Extra: Running in Simulation Mode (MuJoCo)
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
|
||||
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
|
||||
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
|
||||
- [Holosoma](https://github.com/amazon-far/holosoma)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
|
||||
@@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
@@ -124,16 +124,23 @@ lerobot-edit-dataset \
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
|
||||
# For memory-constrained systems, users can now specify limits:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.max_episodes_per_batch 50 \
|
||||
--operation.max_frames_per_batch 10000
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
80
docs/source/walloss.mdx
Normal file
80
docs/source/walloss.mdx
Normal file
@@ -0,0 +1,80 @@
|
||||
# WALL-OSS
|
||||
|
||||
WALL-OSS is an open-source foundation model for embodied intelligence, proposed by the [XSquare Robot](https://x2robot.com/en/research/68bc2cde8497d7f238dde690) team in 2025. The LeRobot implementation is adapted from their open-source [WallX](https://github.com/X-Square-Robot/wall-x) repository.
|
||||
|
||||
X Square Robot’s WALL-OSS is now integrated into Hugging Face’s LeRobot ecosystem. This is an exciting collaborative project between the LeRobot and X Square Robot teams. You can now post-train, evaluate, and deploy WALL-OSS directly through LeRobot. With this, we’re aiming to make it easier for the open-source robotics community to customize and deploy WALL-OSS foundation models. Read and explore WALL-OSS [paper](https://arxiv.org/pdf/2509.11766) and [code](https://github.com/X-Square-Robot/wall-x).
|
||||
|
||||
## Model Overview
|
||||
|
||||
The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png"
|
||||
alt="An overview of WALL-OSS"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include:
|
||||
|
||||
- **Embodied perception–enhanced multimodal pretraining**: Large-scale training on unified vision–language–action data to strengthen spatial, causal, and manipulation understanding.
|
||||
- **Unified Cross-Level Chain-of-Thought (Uni-CoT)**: A single differentiable framework that unifies high-level instruction reasoning, sub-task decomposition, and fine-grained action synthesis, forming a continuous chain from “understanding” to “execution.”
|
||||
- **Mixture-of-Experts (MoE) action heads**: Dynamically activating experts depending on the task phase and modeling actions in discrete or continuous space to maintain stable VLM priors.
|
||||
- **Two-stage training paradigm**:
|
||||
- **Inspiration stage**: Injecting discrete action priors to strengthen spatial understanding and semantic-action alignment.
|
||||
- **Integration stage**: Using flow matching to achieve high-frequency continuous control.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install WallX dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[wallx]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use WallX in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=wall_x
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
--job_name=wallx_training \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.pretrained_name_or_path=x-square-robot/wall-oss-flow \
|
||||
--policy.prediction_mode=diffusion \
|
||||
--policy.attn_implementation=eager \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Training Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--dataset.repo_id` | The Hugging Face Hub repository ID for your training dataset (e.g., `lerobot/aloha_sim_insertion_human`) |
|
||||
| `--policy.type` | Specifies using the WallX policy architecture |
|
||||
| `--output_dir` | Local directory where training checkpoints and logs will be saved |
|
||||
| `--job_name` | A name identifier for this training run (used in logging/tracking) |
|
||||
| `--policy.repo_id` | Your Hugging Face Hub repo ID where the trained model will be pushed |
|
||||
| `--policy.pretrained_path` | Path to pretrained WallX weights to initialize from (the official WALL-OSS checkpoint) |
|
||||
| `--policy.prediction_mode` | The action prediction strategy: `diffusion` or `fast` - `diffusion` uses iterative denoising for action generation, `fast` uses next token prediction instead |
|
||||
| `--policy.attn_implementation` | Attention implementation backend - `eager` uses standard PyTorch attention (alternatives include `flash_attention_2` or `sdpa`) |
|
||||
| `--steps` | Total number of training steps to run |
|
||||
| `--policy.device` | Device to train on (`cuda` for GPU, `cpu` for CPU) |
|
||||
| `--batch_size` | Number of samples per training batch |
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x).
|
||||
@@ -41,8 +41,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
@@ -97,7 +96,7 @@ def replay(cfg: ReplayConfig):
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(1 / dataset.fps - dt_s)
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -18,7 +18,7 @@ import time
|
||||
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -43,12 +43,13 @@ def main():
|
||||
keyboard.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -69,7 +70,7 @@ def main():
|
||||
_ = robot.send_action(action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -34,12 +34,11 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -26,15 +26,14 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
|
||||
@@ -23,11 +23,10 @@ from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
@@ -96,7 +95,7 @@ def main():
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
@@ -21,14 +21,13 @@ from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
@@ -90,12 +89,13 @@ def main():
|
||||
teleop_device.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -112,7 +112,7 @@ def main():
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -94,9 +94,9 @@ from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
@@ -455,7 +455,18 @@ def demo_cli(cfg: RTCDemoConfig):
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
if config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_pretrained_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
@@ -34,12 +34,11 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -27,16 +27,14 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -24,11 +24,10 @@ from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
@@ -97,7 +96,7 @@ def main():
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
@@ -23,15 +23,13 @@ from lerobot.processor.converters import (
|
||||
robot_action_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -96,9 +94,10 @@ def main():
|
||||
leader.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
|
||||
|
||||
print("Starting teleop loop...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -118,7 +117,9 @@ def main():
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
log_rerun_data(
|
||||
observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
|
||||
)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
@@ -4,7 +4,7 @@ from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -5,8 +5,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
@@ -5,8 +5,7 @@ from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
@@ -14,8 +14,8 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
|
||||
@@ -5,8 +5,7 @@ from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
@@ -13,16 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: GR00T Locomotion with Pre-loaded Policies
|
||||
|
||||
This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||
and passing them to the robot class.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
@@ -31,24 +24,26 @@ import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
|
||||
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # or "g1_29"
|
||||
G1_MODEL = "g1_23" # Or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
|
||||
LOCOMOTION_CONTROL_DT = 0.02
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
# Control parameters
|
||||
ACTION_SCALE = 0.25
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
ANG_VEL_SCALE: float = 0.25
|
||||
DOF_POS_SCALE: float = 1.0
|
||||
DOF_VEL_SCALE: float = 0.05
|
||||
@@ -61,12 +56,12 @@ DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
def load_groot_policies(
|
||||
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from the hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||
"""
|
||||
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||
logger.info(f"Loading GR00T dual-policy system from the hub ({repo_id})...")
|
||||
|
||||
# Download ONNX policies from Hugging Face Hub
|
||||
balance_path = hf_hub_download(
|
||||
@@ -88,15 +83,7 @@ def load_groot_policies(
|
||||
|
||||
|
||||
class GrootLocomotionController:
|
||||
"""
|
||||
Handles GR00T-style locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Dual-policy system (Balance + Walk)
|
||||
- 29-joint observation processing
|
||||
- 15D action output (legs + waist)
|
||||
- Policy inference and motor command generation
|
||||
"""
|
||||
"""GR00T lower-body locomotion controller for the Unitree G1."""
|
||||
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
self.policy_balance = policy_balance
|
||||
@@ -104,9 +91,9 @@ class GrootLocomotionController:
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
|
||||
# GR00T-specific state
|
||||
# Robot state
|
||||
self.groot_qj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_action = np.zeros(15, dtype=np.float32)
|
||||
@@ -116,47 +103,39 @@ class GrootLocomotionController:
|
||||
self.groot_height_cmd = 0.74 # Default base height
|
||||
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# input to gr00t is 6 frames (6*86D=516)
|
||||
# Input to GR00T is 6 frames (6*86D=516)
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
|
||||
# Thread management
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def groot_locomotion_run(self):
|
||||
# get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
if not obs:
|
||||
return
|
||||
|
||||
# get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
# Get command from remote controller
|
||||
if obs["remote.buttons"][0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if obs["remote.buttons"][4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
||||
self.cmd[0] = obs["remote.ly"] # Forward/backward
|
||||
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
|
||||
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
|
||||
|
||||
for i in range(29):
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
# Get joint positions and velocities from flat dict
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.groot_qj_all[idx] = obs[f"{name}.q"]
|
||||
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# adapt observation for g1_23dof
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
self.groot_dqj_all[idx] = 0.0
|
||||
@@ -165,18 +144,18 @@ class GrootLocomotionController:
|
||||
qj_obs = self.groot_qj_all.copy()
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# express imu data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# scale joint positions and velocities before policy inference
|
||||
# Scale joint positions and velocities before policy inference
|
||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# build single frame observation
|
||||
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||
# Build single frame observation
|
||||
self.groot_obs_single[:3] = self.cmd * np.array(CMD_SCALE)
|
||||
self.groot_obs_single[3] = self.groot_height_cmd
|
||||
self.groot_obs_single[4:7] = self.groot_orientation_cmd
|
||||
self.groot_obs_single[7:10] = ang_vel_scaled
|
||||
@@ -194,113 +173,76 @@ class GrootLocomotionController:
|
||||
end_idx = start_idx + 86
|
||||
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
|
||||
|
||||
# Run policy inference (ONNX) with 516D stacked observation
|
||||
|
||||
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||
|
||||
cmd_magnitude = np.linalg.norm(self.cmd)
|
||||
selected_policy = (
|
||||
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
|
||||
) # balance/standing policy for small commands, walking policy for movement commands
|
||||
) # Balance/standing policy for small commands, walking policy for movement commands
|
||||
|
||||
# run policy inference
|
||||
# Run policy inference
|
||||
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
|
||||
ort_outs = selected_policy.run(None, ort_inputs)
|
||||
self.groot_action = ort_outs[0].squeeze()
|
||||
|
||||
# transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
||||
# Transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE
|
||||
|
||||
# command motors
|
||||
# Build action dict (only first 15 joints for GR00T)
|
||||
action_dict = {}
|
||||
for i in range(15):
|
||||
motor_idx = i
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
motor_name = G1_29_JointIndex(i).name
|
||||
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i])
|
||||
|
||||
# adapt action for g1_23dof
|
||||
# Zero out missing joints for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
self.robot.msg.motor_cmd[joint_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[joint_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
||||
motor_name = G1_29_JointIndex(joint_idx).name
|
||||
action_dict[f"{motor_name}.q"] = 0.0
|
||||
|
||||
# send action to robot
|
||||
self.robot.send_action(self.robot.msg)
|
||||
# Send action to robot
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
|
||||
def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
"""Main function to run the GR00T locomotion controller.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID for GR00T policies.
|
||||
"""
|
||||
# Load policies
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=repo_id)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
robot.connect()
|
||||
|
||||
# Initialize gr00T locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
policy_walk=policy_walk,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
try:
|
||||
robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES)
|
||||
|
||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.groot_locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
|
||||
# Sleep to maintain control rate
|
||||
groot_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
|
||||
total_time = 3.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
# Only control legs, not arms (first 12 joints)
|
||||
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
|
||||
dof_size = len(default_pos)
|
||||
|
||||
# Get current lowstate
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Record the current leg positions
|
||||
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
|
||||
for i in range(dof_size):
|
||||
init_dof_pos[i] = robot_state.motor_state[i].q
|
||||
|
||||
# Move legs to default pos
|
||||
for i in range(num_step):
|
||||
alpha = i / num_step
|
||||
for motor_idx in range(dof_size):
|
||||
target_pos = default_pos[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
logger.info("Reached default position (legs only)")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping locomotion...")
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -313,35 +255,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# load policies
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
|
||||
|
||||
# initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# initialize gr00t locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
policy_walk=policy_walk,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# reset legs and start locomotion thread
|
||||
try:
|
||||
groot_controller.reset_robot()
|
||||
groot_controller.start_locomotion_thread()
|
||||
|
||||
# log status
|
||||
logger.info("Robot initialized with GR00T locomotion policies")
|
||||
logger.info("Locomotion controller running in background thread")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# keep robot alive
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
groot_controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
run(repo_id=args.repo_id)
|
||||
|
||||
264
examples/unitree_g1/holosoma_locomotion.py
Normal file
264
examples/unitree_g1/holosoma_locomotion.py
Normal file
@@ -0,0 +1,264 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
DEFAULT_ANGLES[[0, 6]] = -0.312 # Hip pitch
|
||||
DEFAULT_ANGLES[[3, 9]] = 0.669 # Knee
|
||||
DEFAULT_ANGLES[[4, 10]] = -0.363 # Ankle pitch
|
||||
DEFAULT_ANGLES[[15, 22]] = 0.2 # Shoulder pitch
|
||||
DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll
|
||||
DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll
|
||||
DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # Or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
# Control parameters
|
||||
ACTION_SCALE = 0.25
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 1.0
|
||||
|
||||
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
|
||||
# Policy filename mapping
|
||||
POLICY_FILES = {
|
||||
"fastsac": "fastsac_g1_29dof.onnx",
|
||||
"ppo": "ppo_g1_29dof.onnx",
|
||||
}
|
||||
|
||||
|
||||
def load_policy(
|
||||
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
|
||||
policy_type: str = "fastsac",
|
||||
) -> tuple[ort.InferenceSession, np.ndarray, np.ndarray]:
|
||||
"""Load Holosoma locomotion policy and extract KP/KD from metadata.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repo ID
|
||||
policy_type: Either "fastsac" (default) or "ppo"
|
||||
|
||||
Returns:
|
||||
(policy, kp, kd) tuple
|
||||
"""
|
||||
if policy_type not in POLICY_FILES:
|
||||
raise ValueError(f"Unknown policy type: {policy_type}. Choose from: {list(POLICY_FILES.keys())}")
|
||||
|
||||
filename = POLICY_FILES[policy_type]
|
||||
logger.info(f"Loading {policy_type.upper()} policy from: {repo_id}/{filename}")
|
||||
policy_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
|
||||
policy = ort.InferenceSession(policy_path)
|
||||
logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}")
|
||||
|
||||
# Extract KP/KD from ONNX metadata
|
||||
model = onnx.load(policy_path)
|
||||
metadata = {prop.key: prop.value for prop in model.metadata_props}
|
||||
|
||||
if "kp" not in metadata or "kd" not in metadata:
|
||||
raise ValueError("ONNX model must contain 'kp' and 'kd' in metadata")
|
||||
|
||||
kp = np.array(json.loads(metadata["kp"]), dtype=np.float32)
|
||||
kd = np.array(json.loads(metadata["kd"]), dtype=np.float32)
|
||||
logger.info(f"Loaded KP/KD from ONNX ({len(kp)} joints)")
|
||||
|
||||
return policy, kp, kd
|
||||
|
||||
|
||||
class HolosomaLocomotionController:
|
||||
"""Holosoma whole-body locomotion controller for Unitree G1."""
|
||||
|
||||
def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
|
||||
# Override robot's PD gains with policy gains
|
||||
self.robot.kp = kp
|
||||
self.robot.kd = kd
|
||||
|
||||
self.cmd = np.zeros(3, dtype=np.float32)
|
||||
|
||||
# Robot state
|
||||
self.qj = np.zeros(29, dtype=np.float32)
|
||||
self.dqj = np.zeros(29, dtype=np.float32)
|
||||
self.obs = np.zeros(100, dtype=np.float32)
|
||||
self.last_action = np.zeros(29, dtype=np.float32)
|
||||
|
||||
# Gait phase
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
|
||||
self.is_standing = True
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
if not obs:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
|
||||
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
|
||||
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
|
||||
self.cmd[:] = [ly, -lx, -rx]
|
||||
|
||||
# Get joint positions and velocities
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.qj[idx] = obs[f"{name}.q"]
|
||||
self.dqj[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.qj[idx] = 0.0
|
||||
self.dqj[idx] = 0.0
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_s = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Update gait phase
|
||||
if np.linalg.norm(self.cmd[:2]) < 0.01 and abs(self.cmd[2]) < 0.01:
|
||||
self.phase[0, :] = np.pi
|
||||
self.is_standing = True
|
||||
elif self.is_standing:
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = False
|
||||
else:
|
||||
self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2 * np.pi) - np.pi
|
||||
|
||||
sin_ph = np.sin(self.phase[0])
|
||||
cos_ph = np.cos(self.phase[0])
|
||||
|
||||
# Build observations
|
||||
self.obs[0:29] = self.last_action
|
||||
self.obs[29:32] = ang_vel_s
|
||||
self.obs[32] = self.cmd[2]
|
||||
self.obs[33:35] = self.cmd[:2]
|
||||
self.obs[35:37] = cos_ph
|
||||
self.obs[37:66] = qj_obs
|
||||
self.obs[66:95] = dqj_obs
|
||||
self.obs[95:98] = gravity
|
||||
self.obs[98:100] = sin_ph
|
||||
|
||||
# Run policy inference
|
||||
ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)}
|
||||
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||
action = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = action.copy()
|
||||
|
||||
# Transform action back to target joint positions
|
||||
target = DEFAULT_ANGLES + action * ACTION_SCALE
|
||||
|
||||
# Build action dict
|
||||
action_dict = {}
|
||||
for motor in G1_29_JointIndex:
|
||||
action_dict[f"{motor.name}.q"] = float(target[motor.value])
|
||||
|
||||
# Zero out missing joints for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
motor_name = G1_29_JointIndex(joint_idx).name
|
||||
action_dict[f"{motor_name}.q"] = 0.0
|
||||
|
||||
# Send action to robot
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
|
||||
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None:
|
||||
"""Main function to run the Holosoma locomotion controller.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID for Holosoma policies.
|
||||
policy_type: Policy type to use ('fastsac' or 'ppo').
|
||||
"""
|
||||
# Load policy and gains
|
||||
policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
robot.connect()
|
||||
|
||||
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
|
||||
|
||||
try:
|
||||
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
|
||||
|
||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
holosoma_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping locomotion...")
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_HOLOSOMA_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy",
|
||||
type=str,
|
||||
choices=["fastsac", "ppo"],
|
||||
default="fastsac",
|
||||
help="Policy type to use: 'fastsac' (default) or 'ppo'",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run(repo_id=args.repo_id, policy_type=args.policy)
|
||||
@@ -27,7 +27,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
name = "lerobot"
|
||||
version = "0.4.3"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
requires-python = ">=3.10"
|
||||
authors = [
|
||||
@@ -74,7 +74,7 @@ dependencies = [
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
@@ -97,7 +97,7 @@ dependencies = [
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -109,9 +109,9 @@ hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0"
|
||||
"onnxruntime>=1.16.0,<2.0.0"
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
@@ -127,7 +127,7 @@ wallx = [
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -140,12 +140,14 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
|
||||
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -168,12 +170,13 @@ all = [
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
# "lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[audio]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
"lerobot[video_benchmark]",
|
||||
@@ -182,7 +185,8 @@ all = [
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]"
|
||||
"lerobot[sarm]",
|
||||
"lerobot[peft]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -195,6 +199,7 @@ lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.lerobot_eval:main"
|
||||
lerobot-train="lerobot.scripts.lerobot_train:main"
|
||||
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
|
||||
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
@@ -405,6 +410,10 @@ conflicts = [
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
@@ -415,6 +424,47 @@ conflicts = [
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
# pi uses custom branch which conflicts with transformers-dep
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
72
setup.py
Normal file
72
setup.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
|
||||
def get_version_from_toml() -> str:
|
||||
"""Return the project's version string parsed from `pyproject.toml`.
|
||||
|
||||
The function scans `pyproject.toml` line-by-line looking for a line
|
||||
that starts with ``version`` (for example: ``version = "1.2.3"``)
|
||||
and returns the value without surrounding quotes. If no such line is
|
||||
found a :class:`ValueError` is raised.
|
||||
|
||||
Returns:
|
||||
The version string from `pyproject.toml` (e.g. ``"1.2.3"`` ->
|
||||
``1.2.3``).
|
||||
"""
|
||||
|
||||
version = None
|
||||
with open("pyproject.toml", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip().startswith("version"):
|
||||
version = line.split("=")[1].strip().strip('"')
|
||||
break
|
||||
if version is None:
|
||||
raise ValueError("Version not found in pyproject.toml")
|
||||
return version
|
||||
|
||||
|
||||
def read_long_description() -> str:
|
||||
"""Read and return the project's long description for setup.
|
||||
|
||||
This function reads `README.md` and replaces image links that point
|
||||
to the local `./media/` directory with absolute raw GitHub URLs that
|
||||
reference the release tag corresponding to the version parsed from
|
||||
`pyproject.toml` (for example, ``v1.2.3``). The modified README
|
||||
content is returned as a string suitable for passing to
|
||||
``setuptools.setup(long_description=...)``.
|
||||
|
||||
Returns:
|
||||
The README content with rewritten media links.
|
||||
"""
|
||||
|
||||
with open("README.md", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
version = get_version_from_toml()
|
||||
git_tag = f"v{version}"
|
||||
|
||||
base_raw_url = f"https://raw.githubusercontent.com/huggingface/lerobot/{git_tag}/"
|
||||
content = content.replace('src="./media/', f'src="{base_raw_url}media/')
|
||||
|
||||
return content
|
||||
|
||||
|
||||
setup(
|
||||
long_description=read_long_description(),
|
||||
long_description_content_type="text/markdown",
|
||||
)
|
||||
@@ -29,6 +29,7 @@ Example:
|
||||
print(lerobot.available_policies_per_env)
|
||||
print(lerobot.available_robots)
|
||||
print(lerobot.available_cameras)
|
||||
print(lerobot.available_microphones)
|
||||
print(lerobot.available_motors)
|
||||
```
|
||||
|
||||
@@ -174,6 +175,13 @@ available_cameras = [
|
||||
"intelrealsense",
|
||||
]
|
||||
|
||||
# lists all available microphones from `lerobot/microphones`
|
||||
available_microphones = [
|
||||
"portaudio",
|
||||
"touchlab",
|
||||
"anyskin",
|
||||
]
|
||||
|
||||
# lists all available motors from `lerobot/motors`
|
||||
available_motors = [
|
||||
"dynamixel",
|
||||
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
|
||||
@@ -40,7 +40,6 @@ from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
@@ -48,15 +47,18 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so100_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
@@ -352,7 +354,7 @@ class RobotClient:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
def control_loop_action(self, verbose: bool = False) -> RobotAction:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
|
||||
@@ -35,18 +35,19 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address="192.168.0.200", # IP address of the robot
|
||||
fps=15,
|
||||
port=50065, # Port of the camera server
|
||||
width=640,
|
||||
height=480,
|
||||
fps=30, # Not configurable for Reachy 2 cameras
|
||||
color_mode=ColorMode.RGB,
|
||||
) # Left teleop camera, 640x480 @ 15FPS
|
||||
) # Left teleop camera, 640x480 @ 30FPS
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: Name of the camera device. Can be "teleop" or "depth".
|
||||
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
|
||||
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
|
||||
fps: Requested frames per second for the color stream.
|
||||
fps: Requested frames per second for the color stream. Not configurable for Reachy 2 cameras.
|
||||
width: Requested frame width in pixels for the color stream.
|
||||
height: Requested frame height in pixels for the color stream.
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
@@ -62,7 +63,6 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
ip_address: str | None = "localhost"
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
|
||||
@@ -16,12 +16,13 @@
|
||||
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
@@ -30,10 +31,19 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
CameraManager,
|
||||
)
|
||||
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
else:
|
||||
CameraManager = None
|
||||
|
||||
class CameraView:
|
||||
LEFT = 0
|
||||
RIGHT = 1
|
||||
|
||||
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -69,17 +79,10 @@ class Reachy2Camera(Camera):
|
||||
|
||||
self.config = config
|
||||
|
||||
self.fps = config.fps
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
|
||||
|
||||
@@ -100,44 +103,23 @@ class Reachy2Camera(Camera):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"Could not connect to {self}.")
|
||||
self.cam_manager.initialize_cameras()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detects available Reachy 2 cameras.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains 'name', 'stereo',
|
||||
and the default profile properties (width, height, fps).
|
||||
Detection not implemented for Reachy2 cameras.
|
||||
"""
|
||||
initialized_cameras = []
|
||||
camera_manager = CameraManager(host=ip_address, port=port)
|
||||
|
||||
for camera in [camera_manager.teleop, camera_manager.depth]:
|
||||
if camera is None:
|
||||
continue
|
||||
|
||||
height, width, _, _, _, _, _ = camera.get_parameters()
|
||||
|
||||
camera_info = {
|
||||
"name": camera._cam_info.name,
|
||||
"stereo": camera._cam_info.stereo,
|
||||
"default_profile": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"fps": 30,
|
||||
},
|
||||
}
|
||||
initialized_cameras.append(camera_info)
|
||||
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -155,95 +137,49 @@ class Reachy2Camera(Camera):
|
||||
(height, width, channels), using the specified or default
|
||||
color mode and applying any configured rotation.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.config.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(
|
||||
CameraView.LEFT, size=(self.config.width, self.config.height)
|
||||
)[0]
|
||||
elif self.config.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(
|
||||
CameraView.RIGHT, size=(self.config.width, self.config.height)
|
||||
)[0]
|
||||
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.config.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()[0]
|
||||
elif self.config.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame(size=(self.config.width, self.config.height))[0]
|
||||
else:
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.config.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
|
||||
elif self.config.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
|
||||
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.config.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()[0]
|
||||
elif self.config.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
Reads the latest available frame.
|
||||
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -261,22 +197,10 @@ class Reachy2Camera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
frame = self.read()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
|
||||
return frame
|
||||
|
||||
@@ -287,12 +211,9 @@ class Reachy2Camera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected and self.thread is None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
|
||||
@@ -43,6 +43,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
elif cfg.type == "zmq":
|
||||
from .zmq.camera_zmq import ZMQCamera
|
||||
|
||||
cameras[key] = ZMQCamera(cfg)
|
||||
|
||||
else:
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,5 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_so100_follower import BiSO100Follower
|
||||
from .config_bi_so100_follower import BiSO100FollowerConfig
|
||||
from .camera_zmq import ZMQCamera
|
||||
from .configuration_zmq import ZMQCameraConfig
|
||||
|
||||
__all__ = ["ZMQCamera", "ZMQCameraConfig"]
|
||||
235
src/lerobot/cameras/zmq/camera_zmq.py
Normal file
235
src/lerobot/cameras/zmq/camera_zmq.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
ZMQCamera - Captures frames from remote cameras via ZeroMQ using JSON protocol in the
|
||||
following format:
|
||||
{
|
||||
"timestamps": {"camera_name": float},
|
||||
"images": {"camera_name": "<base64-jpeg>"}
|
||||
}
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
from .configuration_zmq import ZMQCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZMQCamera(Camera):
|
||||
"""
|
||||
Example usage:
|
||||
```python
|
||||
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
|
||||
|
||||
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
|
||||
camera = ZMQCamera(config)
|
||||
camera.connect()
|
||||
frame = camera.read()
|
||||
camera.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZMQCameraConfig):
|
||||
super().__init__(config)
|
||||
import zmq
|
||||
|
||||
self.config = config
|
||||
self.server_address = config.server_address
|
||||
self.port = config.port
|
||||
self.camera_name = config.camera_name
|
||||
self.color_mode = config.color_mode
|
||||
self.timeout_ms = config.timeout_ms
|
||||
|
||||
self.context: zmq.Context | None = None
|
||||
self.socket: zmq.Socket | None = None
|
||||
self._connected = False
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ZMQCamera({self.camera_name}@{self.server_address}:{self.port})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server."""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
logger.info(f"Connecting to {self}...")
|
||||
|
||||
try:
|
||||
import zmq
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.SUB)
|
||||
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
|
||||
self.socket.setsockopt(zmq.CONFLATE, True)
|
||||
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
|
||||
self._connected = True
|
||||
|
||||
# Auto-detect resolution
|
||||
if self.width is None or self.height is None:
|
||||
h, w = self.read().shape[:2]
|
||||
self.height = h
|
||||
self.width = w
|
||||
logger.info(f"{self} resolution: {w}x{h}")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
if warmup:
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
self._cleanup()
|
||||
raise RuntimeError(f"Failed to connect to {self}: {e}") from e
|
||||
|
||||
def _cleanup(self):
|
||||
"""Clean up ZMQ resources."""
|
||||
self._connected = False
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
if self.context:
|
||||
self.context.term()
|
||||
self.context = None
|
||||
|
||||
@staticmethod
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
"""ZMQ cameras require manual configuration (server address/port)."""
|
||||
return []
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Read a single frame from the ZMQ camera.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
"""
|
||||
if not self.is_connected or self.socket is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
|
||||
# Decode JSON message
|
||||
data = json.loads(message)
|
||||
|
||||
if "images" not in data:
|
||||
raise RuntimeError(f"{self} invalid message: missing 'images' key")
|
||||
|
||||
images = data["images"]
|
||||
|
||||
# Get image by camera name or first available
|
||||
if self.camera_name in images:
|
||||
img_b64 = images[self.camera_name]
|
||||
elif images:
|
||||
img_b64 = next(iter(images.values()))
|
||||
else:
|
||||
raise RuntimeError(f"{self} no images in message")
|
||||
|
||||
# Decode base64 JPEG
|
||||
img_bytes = base64.b64decode(img_b64)
|
||||
frame = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"{self} failed to decode image")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
while self.stop_event and not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self.read()
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.new_frame_event.set()
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Read error: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.thread and self.thread.is_alive():
|
||||
return
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event:
|
||||
self.stop_event.set()
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
|
||||
"""Read latest frame asynchronously (non-blocking)."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if not self.thread or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"{self} no frame available")
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from ZMQ camera."""
|
||||
if not self.is_connected and not self.thread:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
self._stop_read_thread()
|
||||
self._cleanup()
|
||||
logger.info(f"{self} disconnected.")
|
||||
46
src/lerobot/cameras/zmq/configuration_zmq.py
Normal file
46
src/lerobot/cameras/zmq/configuration_zmq.py
Normal file
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["ZMQCameraConfig", "ColorMode"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("zmq")
|
||||
@dataclass
|
||||
class ZMQCameraConfig(CameraConfig):
|
||||
server_address: str
|
||||
port: int = 5555
|
||||
camera_name: str = "zmq_camera"
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
timeout_ms: int = 5000
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.timeout_ms <= 0:
|
||||
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
|
||||
|
||||
if not self.server_address:
|
||||
raise ValueError("`server_address` cannot be empty.")
|
||||
|
||||
if self.port <= 0 or self.port > 65535:
|
||||
raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")
|
||||
114
src/lerobot/cameras/zmq/image_server.py
Normal file
114
src/lerobot/cameras/zmq/image_server.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Streams camera images over ZMQ.
|
||||
Uses lerobot's OpenCVCamera for capture, encodes images to base64 and sends them over ZMQ.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
from lerobot.cameras.configs import ColorMode
|
||||
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def encode_image(image: np.ndarray, quality: int = 80) -> str:
|
||||
"""Encode RGB image to base64 JPEG string."""
|
||||
_, buffer = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
|
||||
return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
class ImageServer:
|
||||
def __init__(self, config: dict, port: int = 5555):
|
||||
self.fps = config.get("fps", 30)
|
||||
self.cameras: dict[str, OpenCVCamera] = {}
|
||||
|
||||
for name, cfg in config.get("cameras", {}).items():
|
||||
shape = cfg.get("shape", [480, 640])
|
||||
cam_config = OpenCVCameraConfig(
|
||||
index_or_path=cfg.get("device_id", 0),
|
||||
fps=self.fps,
|
||||
width=shape[1],
|
||||
height=shape[0],
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
camera = OpenCVCamera(cam_config)
|
||||
camera.connect()
|
||||
self.cameras[name] = camera
|
||||
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
|
||||
|
||||
# ZMQ PUB socket
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
self.socket.setsockopt(zmq.SNDHWM, 20)
|
||||
self.socket.setsockopt(zmq.LINGER, 0)
|
||||
self.socket.bind(f"tcp://*:{port}")
|
||||
|
||||
logger.info(f"ImageServer running on port {port}")
|
||||
|
||||
def run(self):
|
||||
frame_count = 0
|
||||
frame_times = deque(maxlen=60)
|
||||
|
||||
try:
|
||||
while True:
|
||||
t0 = time.time()
|
||||
|
||||
# Build message
|
||||
message = {"timestamps": {}, "images": {}}
|
||||
for name, cam in self.cameras.items():
|
||||
frame = cam.read() # Returns RGB
|
||||
message["timestamps"][name] = time.time()
|
||||
message["images"][name] = encode_image(frame)
|
||||
|
||||
# Send as JSON string (suppress if buffer full)
|
||||
with contextlib.suppress(zmq.Again):
|
||||
self.socket.send_string(json.dumps(message), zmq.NOBLOCK)
|
||||
|
||||
frame_count += 1
|
||||
frame_times.append(time.time() - t0)
|
||||
|
||||
if frame_count % 60 == 0:
|
||||
logger.debug(f"FPS: {len(frame_times) / sum(frame_times):.1f}")
|
||||
|
||||
sleep = (1.0 / self.fps) - (time.time() - t0)
|
||||
if sleep > 0:
|
||||
time.sleep(sleep)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
config = {"fps": 30, "cameras": {"head_camera": {"device_id": 4, "shape": [480, 640]}}}
|
||||
ImageServer(config, port=5555).run()
|
||||
@@ -67,3 +67,31 @@ class EvalConfig:
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeftConfig:
|
||||
# PEFT offers many fine-tuning methods, layer adapters being the most common and currently also the most
|
||||
# effective methods so we'll focus on those in this high-level config interface.
|
||||
|
||||
# Either a string (module name suffix or 'all-linear'), a list of module name suffixes or a regular expression
|
||||
# describing module names to target with the configured PEFT method. Some policies have a default value for this
|
||||
# so that you don't *have* to choose which layers to adapt but it might still be worthwhile depending on your case.
|
||||
target_modules: list[str] | str | None = None
|
||||
|
||||
# Names/suffixes of modules to fully fine-tune and store alongside adapter weights. Useful for layers that are
|
||||
# not part of a pre-trained model (e.g., action state projections). Depending on the policy this defaults to layers
|
||||
# that are newly created in pre-trained policies. If you're fine-tuning an already trained policy you might want
|
||||
# to set this to `[]`. Corresponds to PEFT's `modules_to_save`.
|
||||
full_training_modules: list[str] | None = None
|
||||
|
||||
# The PEFT (adapter) method to apply to the policy. Needs to be a valid PEFT type.
|
||||
method_type: str = "LORA"
|
||||
|
||||
# Adapter initialization method. Look at the specific PEFT adapter documentation for defaults.
|
||||
init_type: str | None = None
|
||||
|
||||
# We expect that all PEFT adapters are in some way doing rank-decomposition therefore this parameter specifies
|
||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||
# fine-tuning.
|
||||
r: int = 16
|
||||
|
||||
@@ -38,6 +38,8 @@ class EvalPipelineConfig:
|
||||
seed: int | None = 1000
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
# Explicit consent to execute remote code from the Hub (required for hub environments).
|
||||
trust_remote_code: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
|
||||
@@ -55,14 +55,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
# `input_features` can be set to None/null in order to infer those values from the dataset.
|
||||
input_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
# Whether the policy employed PEFT for training.
|
||||
use_peft: bool = False
|
||||
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -125,6 +129,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
@property
|
||||
def robot_state_feature(self) -> PolicyFeature | None:
|
||||
if not self.input_features:
|
||||
return None
|
||||
for ft_name, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
|
||||
return ft
|
||||
@@ -132,6 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
@property
|
||||
def env_state_feature(self) -> PolicyFeature | None:
|
||||
if not self.input_features:
|
||||
return None
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.ENV:
|
||||
return ft
|
||||
@@ -139,10 +147,20 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
if not self.input_features:
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||
if not self.input_features:
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
if not self.output_features:
|
||||
return None
|
||||
for ft_name, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION and ft_name == ACTION:
|
||||
return ft
|
||||
|
||||
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.optim import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -65,6 +65,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
|
||||
@@ -20,6 +20,7 @@ from enum import Enum
|
||||
class FeatureType(str, Enum):
|
||||
STATE = "STATE"
|
||||
VISUAL = "VISUAL"
|
||||
AUDIO = "AUDIO"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
@@ -19,12 +19,15 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -32,6 +35,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
@@ -39,7 +43,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -108,6 +112,7 @@ def update_meta_data(
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
):
|
||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||
|
||||
@@ -120,7 +125,7 @@ def update_meta_data(
|
||||
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||
data_idx: Dictionary containing current data chunk and file indices.
|
||||
videos_idx: Dictionary containing current video indices and timestamps.
|
||||
|
||||
audios_idx: Dictionary containing current audio indices and timestamps.
|
||||
Returns:
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
@@ -178,6 +183,36 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
# Store original audio file indices before updating
|
||||
orig_chunk_col = f"audio/{key}/chunk_index"
|
||||
orig_file_col = f"audio/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = audio_idx["chunk"]
|
||||
df[orig_file_col] = audio_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = audio_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"audio/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"audio/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"audio/{key}/from_timestamp"] = (
|
||||
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
|
||||
)
|
||||
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
@@ -192,6 +227,7 @@ def aggregate_datasets(
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
audio_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -209,6 +245,7 @@ def aggregate_datasets(
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
@@ -217,6 +254,8 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_files_size_in_mb is None:
|
||||
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
@@ -229,6 +268,7 @@ def aggregate_datasets(
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=aggr_repo_id,
|
||||
@@ -240,6 +280,7 @@ def aggregate_datasets(
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
audio_files_size_in_mb=audio_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -251,14 +292,18 @@ def aggregate_datasets(
|
||||
videos_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
audios_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
@@ -326,7 +371,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
src_duration = get_media_duration_in_s(src_path, media_type="video")
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
|
||||
if not dst_path.exists():
|
||||
@@ -365,7 +410,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
@@ -380,6 +425,101 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
|
||||
"""Aggregates audio files from a source dataset into the destination dataset.
|
||||
|
||||
Handles audio file concatenation and rotation based on file size limits.
|
||||
Creates new audio files when size limits are exceeded.
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
audio_idx: Dictionary tracking audio chunk and file indices.
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated audio_idx with current chunk and file indices.
|
||||
"""
|
||||
for key in audios_idx:
|
||||
audios_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
audios_idx[key]["src_to_offset"] = {}
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
for chunk, file in zip(
|
||||
src_meta.episodes[f"audio/{key}/chunk_index"],
|
||||
src_meta.episodes[f"audio/{key}/file_index"],
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
|
||||
|
||||
chunk_idx = audio_idx["chunk"]
|
||||
file_idx = audio_idx["file"]
|
||||
current_offset = audio_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
file_index=src_file_idx,
|
||||
)
|
||||
|
||||
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
src_duration = get_media_duration_in_s(src_path, media_type="audio")
|
||||
|
||||
if not dst_path.exists():
|
||||
# Store offset before incrementing
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
audios_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= audio_files_size_in_mb:
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Append to existing video file - use current accumulated offset
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_media_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
current_offset += src_duration
|
||||
|
||||
audios_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
audios_idx[key]["chunk"] = chunk_idx
|
||||
audios_idx[key]["file"] = file_idx
|
||||
|
||||
return audios_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
@@ -402,12 +542,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
}
|
||||
|
||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||
contains_images = len(dst_meta.image_keys) > 0
|
||||
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
)
|
||||
df = pd.read_parquet(src_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read source data to preserve image format
|
||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
||||
df = src_ds.to_pandas()
|
||||
else:
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
@@ -417,14 +566,15 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
data_files_size_in_mb,
|
||||
chunk_size,
|
||||
DEFAULT_DATA_PATH,
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
|
||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||
|
||||
Reads source metadata files, updates all indices and timestamps,
|
||||
@@ -436,6 +586,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
videos_idx: Dictionary tracking video indices and timestamps.
|
||||
audios_idx: Dictionary tracking audio indices and timestamps.
|
||||
|
||||
Returns:
|
||||
dict: Updated meta_idx with current chunk and file indices.
|
||||
@@ -459,6 +610,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
@@ -475,7 +627,8 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
for k in audios_idx:
|
||||
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
|
||||
return meta_idx
|
||||
|
||||
|
||||
@@ -488,6 +641,7 @@ def append_or_create_parquet_file(
|
||||
default_path: str,
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
):
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -503,6 +657,7 @@ def append_or_create_parquet_file(
|
||||
default_path: Format string for generating file paths.
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
@@ -512,7 +667,7 @@ def append_or_create_parquet_file(
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path)
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
@@ -527,12 +682,17 @@ def append_or_create_parquet_file(
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read existing data to preserve image format
|
||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
||||
existing_df = existing_ds.to_pandas()
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
target_path = dst_path
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
275
src/lerobot/datasets/audio_utils.py
Normal file
275
src/lerobot/datasets/audio_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchcodec
|
||||
from numpy import ceil
|
||||
|
||||
CHANNELS_LAYOUTS_MAPPING = {
|
||||
1: "mono",
|
||||
2: "stereo",
|
||||
3: "2.1",
|
||||
4: "3.1",
|
||||
5: "4.1",
|
||||
6: "5.1",
|
||||
7: "6.1",
|
||||
8: "7.1",
|
||||
16: "hexadecagonal",
|
||||
24: "22.2",
|
||||
}
|
||||
|
||||
|
||||
def decode_audio(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
backend: str | None = "torchcodec",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes audio using the specified backend.
|
||||
Args:
|
||||
audio_path (Path): Path to the audio file.
|
||||
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
|
||||
duration (float): Duration of the audio chunks in seconds.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded audio chunks.
|
||||
|
||||
Currently supports torchaudio.
|
||||
"""
|
||||
if backend == "torchcodec":
|
||||
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
|
||||
elif backend == "torchaudio":
|
||||
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_audio_torchcodec(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# TODO(CarolinePascal) : add channels selection
|
||||
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
|
||||
audio_sample_rate = audio_decoder.metadata.sample_rate
|
||||
audio_channels = audio_decoder.metadata.num_channels
|
||||
# TODO(CarolinePascal) : assert ts < total record duration
|
||||
|
||||
audio_chunks = []
|
||||
timestamps = [
|
||||
timestamp + start_time_s for timestamp in timestamps
|
||||
] # Add an offset of start_time_s to each timestamp
|
||||
for ts in timestamps:
|
||||
current_audio_chunk = audio_decoder.get_samples_played_in_range(
|
||||
start_seconds=max(0.0, ts - duration), stop_seconds=ts
|
||||
)
|
||||
|
||||
current_audio_chunk_data = current_audio_chunk.data
|
||||
|
||||
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||
if ts - duration < 0:
|
||||
# No useful audio sample has been recorded
|
||||
if ts < 1 / audio_sample_rate:
|
||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||
current_audio_chunk_data = torch.zeros(
|
||||
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||
)
|
||||
# At least one useful audio sample has been recorded
|
||||
else:
|
||||
# Pad the beginning of the audio chunk with zeros
|
||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||
current_audio_chunk_data = torch.nn.functional.pad(
|
||||
current_audio_chunk_data,
|
||||
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||
)
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(
|
||||
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
|
||||
)
|
||||
|
||||
audio_chunks.append(current_audio_chunk_data)
|
||||
|
||||
audio_chunks = torch.stack(audio_chunks)
|
||||
|
||||
assert len(timestamps) == len(audio_chunks)
|
||||
return audio_chunks
|
||||
|
||||
|
||||
def decode_audio_torchaudio(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# TODO(CarolinePascal) : add channels selection
|
||||
audio_path = str(audio_path)
|
||||
|
||||
reader = torchaudio.io.StreamReader(src=audio_path)
|
||||
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
|
||||
audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
|
||||
# TODO(CarolinePascal) : assert ts < total record duration
|
||||
|
||||
# TODO(CarolinePascal) : sort timestamps ?
|
||||
|
||||
reader.add_basic_audio_stream(
|
||||
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
|
||||
buffer_chunk_size=-1, # No dropping frames
|
||||
format="fltp", # Format as float32
|
||||
)
|
||||
|
||||
audio_chunks = []
|
||||
timestamps = [
|
||||
timestamp + start_time_s for timestamp in timestamps
|
||||
] # Add an offset of start_time_s to each timestamp
|
||||
for ts in timestamps:
|
||||
reader.seek(max(0.0, ts - duration)) # Default to closest audio sample. Needs to be non-negative !
|
||||
status = reader.fill_buffer()
|
||||
if status != 0:
|
||||
# Should not happen, but just in case
|
||||
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
||||
|
||||
current_audio_chunk = reader.pop_chunks()[0]
|
||||
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
|
||||
|
||||
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||
if ts - duration < 0:
|
||||
# No useful audio sample has been recorded
|
||||
if ts < 1 / audio_sample_rate:
|
||||
current_audio_chunk_data = torch.zeros(
|
||||
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||
)
|
||||
# At least one useful audio sample has been recorded
|
||||
else:
|
||||
# Remove the superfluous last samples of the audio chunk
|
||||
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
|
||||
# Pad the beginning of the audio chunk with zeros
|
||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||
current_audio_chunk_data = torch.nn.functional.pad(
|
||||
current_audio_chunk_data,
|
||||
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||
)
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(
|
||||
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
||||
)
|
||||
|
||||
audio_chunks.append(current_audio_chunk_data)
|
||||
|
||||
audio_chunks = torch.stack(audio_chunks)
|
||||
|
||||
assert len(timestamps) == len(audio_chunks)
|
||||
return audio_chunks
|
||||
|
||||
|
||||
def encode_audio(
|
||||
input_path: Path | str,
|
||||
output_path: Path | str,
|
||||
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
||||
bit_rate: int | None = None,
|
||||
sample_rate: int | None = None,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Encodes an audio file using ffmpeg."""
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=overwrite)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
# "While less efficient, it is generally preferable to modify logging with Python’s logging"
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
|
||||
# Open input file
|
||||
with av.open(str(input_path), "r") as input:
|
||||
input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded
|
||||
|
||||
# Define sub-sampling options
|
||||
if sample_rate is None:
|
||||
sample_rate = input_stream.rate
|
||||
|
||||
# Create and open output file (overwrite by default)
|
||||
with av.open(str(output_path), "w") as output:
|
||||
output_stream = output.add_stream(
|
||||
codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
|
||||
)
|
||||
|
||||
if bit_rate is not None:
|
||||
output_stream.bit_rate = bit_rate
|
||||
|
||||
# Loop through input WAV packets and encode them
|
||||
for input_frame in input.decode(
|
||||
input_stream
|
||||
): # This step handles both demuxing and decoding under the hood
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Reset logging level
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not output_path.exists():
|
||||
raise OSError(f"Audio encoding did not work. File not found: {output_path}.")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
with av.open(str(video_path), "r") as audio_file:
|
||||
try:
|
||||
audio_stream = audio_file.streams.audio[0]
|
||||
except IndexError:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
return {"has_audio": False}
|
||||
|
||||
audio_info["audio.channels"] = audio_stream.channels
|
||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
||||
# In an ideal loseless case : fixed number of bits per sample.
|
||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
||||
audio_info["has_audio"] = True
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
return audio_info
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
@@ -245,6 +245,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
||||
"""Samples audio data from an audio recording stored in a WAV file."""
|
||||
data = load_audio_from_path(audio_path)
|
||||
sampled_indices = sample_indices(len(data))
|
||||
|
||||
return data[sampled_indices]
|
||||
|
||||
|
||||
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
||||
"""Samples audio data from an audio recording stored in a numpy array."""
|
||||
sampled_indices = sample_indices(len(data))
|
||||
return data[sampled_indices]
|
||||
|
||||
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
@@ -512,6 +526,13 @@ def compute_episode_stats(
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
elif features[key]["dtype"] == "audio":
|
||||
try:
|
||||
ep_ft_array = sample_audio_from_path(data[0])
|
||||
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
ep_ft_array = sample_audio_from_data(data)
|
||||
axes_to_reduce = 0
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
|
||||
@@ -26,6 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -51,7 +52,8 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -1083,3 +1085,561 @@ def _copy_episodes_metadata_and_stats(
|
||||
else:
|
||||
if src_dataset.meta.stats:
|
||||
write_stats(src_dataset.meta.stats, dst_meta.root)
|
||||
|
||||
|
||||
def _save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def _save_batch_episodes_images(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
num_workers: int = 4,
|
||||
) -> list[float]:
|
||||
"""Save images from multiple episodes to disk for batch video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_indices: List of episode indices to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
|
||||
Returns:
|
||||
List of episode durations in seconds
|
||||
"""
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
frame_idx = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
episode_durations.append(episode_length / dataset.fps)
|
||||
|
||||
# Get episode images
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Save images
|
||||
items = list(enumerate(episode_dataset))
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
frame_idx += episode_length
|
||||
|
||||
return episode_durations
|
||||
|
||||
|
||||
def _iter_episode_batches(
|
||||
episode_indices: list[int],
|
||||
episode_lengths: dict[int, int],
|
||||
size_per_frame_mb: float,
|
||||
video_file_size_limit: float,
|
||||
max_episodes: int | None,
|
||||
max_frames: int | None,
|
||||
):
|
||||
"""Generator that yields batches of episode indices for video encoding.
|
||||
|
||||
Groups episodes into batches that respect size and memory constraints:
|
||||
- Stays under video file size limit
|
||||
- Respects maximum episodes per batch (if specified)
|
||||
- Respects maximum frames per batch (if specified)
|
||||
|
||||
Args:
|
||||
episode_indices: List of episode indices to batch
|
||||
episode_lengths: Dictionary mapping episode index to episode length
|
||||
size_per_frame_mb: Estimated size per frame in MB
|
||||
video_file_size_limit: Maximum video file size in MB
|
||||
max_episodes: Maximum number of episodes per batch (None = no limit)
|
||||
max_frames: Maximum number of frames per batch (None = no limit)
|
||||
|
||||
Yields:
|
||||
List of episode indices for each batch
|
||||
"""
|
||||
batch_episodes = []
|
||||
estimated_size = 0.0
|
||||
total_frames = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
ep_length = episode_lengths[ep_idx]
|
||||
ep_estimated_size = ep_length * size_per_frame_mb
|
||||
|
||||
# we check if adding this episode would exceed any constraint
|
||||
would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit
|
||||
would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes
|
||||
would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames
|
||||
|
||||
if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames):
|
||||
# yield current batch before adding this episode
|
||||
yield batch_episodes
|
||||
# start a new batch with current episode
|
||||
batch_episodes = [ep_idx]
|
||||
estimated_size = ep_estimated_size
|
||||
total_frames = ep_length
|
||||
else:
|
||||
# add to current batch
|
||||
batch_episodes.append(ep_idx)
|
||||
estimated_size += ep_estimated_size
|
||||
total_frames += ep_length
|
||||
|
||||
# yield final batch if not empty
|
||||
if batch_episodes:
|
||||
yield batch_episodes
|
||||
|
||||
|
||||
def _estimate_frame_size_via_calibration(
|
||||
dataset: LeRobotDataset,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
|
||||
Encodes a representative sample of frames using the exact codec parameters
|
||||
to measure actual compression ratio, which is more accurate than heuristics.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images.
|
||||
img_key: Image key to calibrate (e.g., "observation.images.top").
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
Estimated size in MB per frame based on actual encoding.
|
||||
"""
|
||||
calibration_dir = temp_dir / "calibration" / img_key
|
||||
calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
# Select a representative episode (prefer middle episode if available)
|
||||
calibration_ep_idx = episode_indices[len(episode_indices) // 2]
|
||||
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
|
||||
# Use up to num_calibration_frames from this episode
|
||||
num_frames = min(num_calibration_frames, episode_length)
|
||||
|
||||
# Get frames from dataset
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames
|
||||
for i, idx in enumerate(sample_indices):
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
encode_video_frames(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Measure actual compressed size
|
||||
video_size_bytes = calibration_video_path.stat().st_size
|
||||
video_size_mb = video_size_bytes / BYTES_PER_MIB
|
||||
size_per_frame_mb = video_size_mb / num_frames
|
||||
|
||||
logging.info(
|
||||
f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB "
|
||||
f"= {size_per_frame_mb:.4f} MB/frame for {img_key}"
|
||||
)
|
||||
|
||||
return size_per_frame_mb
|
||||
|
||||
finally:
|
||||
# Clean up calibration files
|
||||
if calibration_dir.exists():
|
||||
shutil.rmtree(calibration_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
# Video conversion constants
|
||||
BYTES_PER_KIB = 1024
|
||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
max_frames_per_batch: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-to-video dataset.
|
||||
|
||||
Creates a new LeRobotDataset with images encoded as videos, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process all episodes and batch encode videos
|
||||
# Use dictionary for O(1) episode metadata lookups instead of O(n) linear search
|
||||
all_episode_metadata = {}
|
||||
fps = int(dataset.fps)
|
||||
|
||||
try:
|
||||
# Build episode metadata entries first
|
||||
logging.info("Building episode metadata...")
|
||||
cumulative_frame_idx = 0
|
||||
for ep_idx in episode_indices:
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
ep_length = src_episode["length"]
|
||||
ep_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": ep_length,
|
||||
"dataset_from_index": cumulative_frame_idx,
|
||||
"dataset_to_index": cumulative_frame_idx + ep_length,
|
||||
}
|
||||
if "data/chunk_index" in src_episode:
|
||||
ep_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
ep_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
all_episode_metadata[ep_idx] = ep_meta
|
||||
cumulative_frame_idx += ep_length
|
||||
|
||||
# Process each camera and batch encode multiple episodes together
|
||||
video_file_size_limit = new_meta.video_files_size_in_mb
|
||||
|
||||
# Pre-compute episode lengths for batching
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
dataset=dataset,
|
||||
img_key=img_key,
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
chunk_idx, file_idx = 0, 0
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Process episodes in batches to stay under size limit
|
||||
for batch_episodes in _iter_episode_batches(
|
||||
episode_indices=episode_indices,
|
||||
episode_lengths=episode_lengths,
|
||||
size_per_frame_mb=size_per_frame_mb,
|
||||
video_file_size_limit=video_file_size_limit,
|
||||
max_episodes=max_episodes_per_batch,
|
||||
max_frames=max_frames_per_batch,
|
||||
):
|
||||
total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes)
|
||||
logging.info(
|
||||
f" Encoding batch of {len(batch_episodes)} episodes "
|
||||
f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames"
|
||||
)
|
||||
|
||||
# Save images for all episodes in this batch
|
||||
imgs_dir = temp_dir / f"batch_{chunk_idx}_{file_idx}" / img_key
|
||||
episode_durations = _save_batch_episodes_images(
|
||||
dataset=dataset,
|
||||
imgs_dir=imgs_dir,
|
||||
img_key=img_key,
|
||||
episode_indices=batch_episodes,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
# Encode all batched episodes into single video
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Update metadata for each episode in the batch
|
||||
for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True):
|
||||
from_timestamp = cumulative_timestamp
|
||||
to_timestamp = cumulative_timestamp + duration
|
||||
cumulative_timestamp = to_timestamp
|
||||
|
||||
# Find episode metadata entry and add video metadata (O(1) dictionary lookup)
|
||||
ep_meta = all_episode_metadata[ep_idx]
|
||||
ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx
|
||||
ep_meta[f"videos/{img_key}/file_index"] = file_idx
|
||||
ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp
|
||||
ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp
|
||||
|
||||
# Move to next video file for next batch
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size)
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(list(all_episode_metadata.values()))
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
@@ -33,12 +33,16 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.datasets.audio_utils import decode_audio, encode_audio, get_audio_info
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||
DEFAULT_RAW_AUDIO_PATH,
|
||||
INFO_PATH,
|
||||
_validate_feature_names,
|
||||
check_delta_timestamps,
|
||||
@@ -68,16 +72,19 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
concatenate_media_files,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_media_duration_in_s,
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.microphones import Microphone
|
||||
from lerobot.microphones.utils import async_microphones_start_recording
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -213,6 +220,19 @@ class LeRobotDatasetMetadata:
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"audio/{audio_key}/chunk_index"]
|
||||
file_idx = ep[f"audio/{audio_key}/file_index"]
|
||||
fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
@@ -223,6 +243,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def audio_path(self) -> str | None:
|
||||
"""Formattable string for the audio files."""
|
||||
return self.info["audio_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
@@ -253,6 +278,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def audio_keys(self) -> list[str]:
|
||||
"""Keys to access audio modalities."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -293,6 +323,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def audio_files_size_in_mb(self) -> int:
|
||||
"""Max size of audio file in mega bytes."""
|
||||
return self.info["audio_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
@@ -434,11 +469,27 @@ class LeRobotDatasetMetadata:
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_audio_info(self, audio_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if audio_key is not None and audio_key not in self.audio_keys:
|
||||
raise ValueError(f"Audio key {audio_key} not found in dataset")
|
||||
|
||||
audio_keys = [audio_key] if audio_key is not None else self.audio_keys
|
||||
for key in audio_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
@@ -450,6 +501,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
@@ -466,6 +518,11 @@ class LeRobotDatasetMetadata:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
if audio_files_size_in_mb is not None:
|
||||
if audio_files_size_in_mb <= 0:
|
||||
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
|
||||
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
@@ -473,12 +530,13 @@ class LeRobotDatasetMetadata:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
"audio_files_size_in_mb": self.audio_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
@@ -505,6 +563,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
@@ -528,6 +587,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
audio_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
@@ -540,11 +600,13 @@ class LeRobotDatasetMetadata:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
|
||||
def _encode_video_worker(
|
||||
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
|
||||
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@@ -561,8 +623,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
download_audio: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -594,6 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
task-conditioned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||
- audio (optional) from which audio is loaded to be synchronous with data from parquet files.
|
||||
|
||||
A typical LeRobotDataset looks like this from its root path:
|
||||
.
|
||||
@@ -619,19 +685,37 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ └── tasks.parquet
|
||||
└── videos
|
||||
├── observation.images.laptop
|
||||
├── videos
|
||||
│ ├── observation.images.laptop
|
||||
│ │ ├── chunk-000
|
||||
│ │ │ ├── file-000.mp4
|
||||
│ │ │ ├── file-001.mp4
|
||||
│ │ │ └── ...
|
||||
│ │ ├── chunk-001
|
||||
│ │ │ └── ...
|
||||
│ │ └── ...
|
||||
│ ├── observation.images.phone
|
||||
│ │ ├── chunk-000
|
||||
│ │ │ ├── file-000.mp4
|
||||
│ │ │ ├── file-001.mp4
|
||||
│ │ │ └── ...
|
||||
│ │ ├── chunk-001
|
||||
│ │ │ └── ...
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
└── audio
|
||||
├── observation.audio.laptop
|
||||
│ ├── chunk-000
|
||||
│ │ ├── file-000.mp4
|
||||
│ │ ├── file-001.mp4
|
||||
│ │ ├── file-000.m4a
|
||||
│ │ ├── file-001.m4a
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── observation.images.phone
|
||||
├── observation.audio.phone
|
||||
│ ├── chunk-000
|
||||
│ │ ├── file-000.mp4
|
||||
│ │ ├── file-001.mp4
|
||||
│ │ ├── file-000.m4a
|
||||
│ │ ├── file-001.m4a
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ └── ...
|
||||
@@ -671,12 +755,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
download_audio (bool, optional): Flag to download the audio. Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||
encoding is CPU-heavy.
|
||||
"""
|
||||
super().__init__()
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
@@ -685,9 +776,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.audio_backend = (
|
||||
audio_backend if audio_backend else "torchcodec"
|
||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
self.vcodec = vcodec
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -756,6 +851,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
push_audio: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
@@ -764,6 +860,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not push_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
@@ -818,7 +916,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download(self, download_videos: bool = True) -> None:
|
||||
def download(self, download_videos: bool = True, download_audio: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
@@ -826,8 +924,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
ignore_patterns = []
|
||||
if not download_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not download_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
@@ -842,6 +944,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
audio_files = [
|
||||
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
|
||||
for audio_key in self.meta.audio_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += audio_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
@@ -854,7 +965,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
@@ -882,6 +993,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
# Check if all required audio files exist
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||
if not audio_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
@@ -925,17 +1044,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps.
|
||||
|
||||
Args:
|
||||
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
|
||||
ep_idx: The episode index.
|
||||
|
||||
Returns:
|
||||
A tuple of (query_indices, padding) where:
|
||||
- query_indices: Dict mapping keys to lists of absolute indices to query
|
||||
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -947,7 +1079,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.meta.video_keys:
|
||||
for key in self.meta.video_keys + self.meta.audio_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
if self._absolute_to_relative_idx is not None:
|
||||
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
||||
@@ -962,7 +1094,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
"""
|
||||
Query dataset for indices across keys, skipping video keys.
|
||||
Query dataset for indices across keys, skipping video keys and audio keys.
|
||||
|
||||
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||
|
||||
@@ -974,7 +1106,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
if key in self.meta.video_keys or key in self.meta.audio_keys:
|
||||
continue
|
||||
# Map absolute indices to relative indices if needed
|
||||
relative_indices = (
|
||||
@@ -1009,6 +1141,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
# TODO(CarolinePascal): add variable query durations
|
||||
def _query_audio(
|
||||
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
|
||||
) -> dict[str, torch.Tensor]:
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
item = {}
|
||||
for audio_key, query_ts in query_timestamps.items():
|
||||
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||
# Thus we load the start timestamp of the episode on this mp4 and,
|
||||
# shift the query timestamp accordingly.
|
||||
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
|
||||
audio_chunk = decode_audio(
|
||||
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
|
||||
)
|
||||
item[audio_key] = audio_chunk.squeeze(0)
|
||||
|
||||
return item
|
||||
|
||||
def _ensure_hf_dataset_loaded(self):
|
||||
"""Lazy load the HF dataset only when needed for reading."""
|
||||
if self._lazy_loading or self.hf_dataset is None:
|
||||
@@ -1027,20 +1181,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
# Use the absolute index from the dataset for delta timestamp calculations
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
|
||||
item = {**item, **video_frames, **audio_chunks}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
@@ -1088,6 +1245,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
@@ -1140,11 +1301,43 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
elif self.features[key]["dtype"] == "audio":
|
||||
if (
|
||||
self.meta.robot_type == "lekiwi"
|
||||
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
else: # Otherwise, only the audio file path is stored in the episode buffer
|
||||
if frame_index == 0:
|
||||
audio_path = self._get_raw_audio_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||
)
|
||||
self.episode_buffer[key].append(str(audio_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
|
||||
"""
|
||||
Starts recording audio data provided by the microphone and directly writes it in a .wav file.
|
||||
"""
|
||||
|
||||
audio_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)
|
||||
microphone.start_recording(output_file=audio_file)
|
||||
|
||||
def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
|
||||
"""
|
||||
Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files.
|
||||
"""
|
||||
|
||||
output_files = []
|
||||
for microphone_key in microphones:
|
||||
output_files.append(
|
||||
self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)
|
||||
)
|
||||
|
||||
async_microphones_start_recording(microphones, output_files)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_data: dict | None = None,
|
||||
@@ -1188,6 +1381,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif ft["dtype"] == "audio":
|
||||
if (
|
||||
self.meta.robot_type == "lekiwi"
|
||||
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
@@ -1196,9 +1395,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
has_audio_keys = len(self.meta.audio_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
if (has_video_keys or has_audio_keys) and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
@@ -1211,6 +1411,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index,
|
||||
self.root,
|
||||
self.fps,
|
||||
self.vcodec,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
@@ -1234,21 +1435,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# TODO(Caroline): add parallel encoding for audio as well
|
||||
for audio_key in self.meta.audio_keys:
|
||||
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
if (has_video_keys or has_audio_keys) and use_batched_encoding:
|
||||
# Check if we should trigger batch encoding
|
||||
self.episodes_since_last_encoding += 1
|
||||
if self.episodes_since_last_encoding == self.batch_encoding_size:
|
||||
start_ep = self.num_episodes - self.batch_encoding_size
|
||||
end_ep = self.num_episodes
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
if has_video_keys:
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
if has_audio_keys:
|
||||
self._batch_save_episode_audio(start_ep, end_ep)
|
||||
self.episodes_since_last_encoding = 0
|
||||
|
||||
if not episode_data:
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
self.clear_episode_buffer(
|
||||
delete_images=len(self.meta.image_keys) > 0, delete_audio=len(self.meta.audio_keys) > 0
|
||||
)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
@@ -1299,7 +1509,70 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
# Save the current episode's audio metadata to the dataframe
|
||||
audio_ep_metadata = {}
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
|
||||
audio_ep_metadata.pop("episode_index")
|
||||
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
episode_df = episode_df.combine_first(video_ep_df)
|
||||
episode_df = episode_df.combine_first(audio_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
Batch save audio for multiple episodes.
|
||||
|
||||
Args:
|
||||
start_episode: Starting episode index (inclusive)
|
||||
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
|
||||
"""
|
||||
if end_episode is None:
|
||||
end_episode = self.num_episodes
|
||||
|
||||
logging.info(
|
||||
f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"]
|
||||
file_idx = self.meta.episodes[start_episode]["data/file_index"]
|
||||
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logging.info(f"Encoding audio for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
or self.meta.episodes[ep_idx]["data/file_index"] != file_idx
|
||||
):
|
||||
# The current episode is in a new chunk or file.
|
||||
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
# Load new episode dataframe
|
||||
chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
|
||||
file_idx = self.meta.episodes[ep_idx]["data/file_index"]
|
||||
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
# Save the current episode's video metadata to the dataframe
|
||||
audio_ep_metadata = {}
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
|
||||
audio_ep_metadata.pop("episode_index")
|
||||
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
episode_df = episode_df.combine_first(audio_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
@@ -1410,7 +1683,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_path = temp_path
|
||||
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||
|
||||
if (
|
||||
episode_index == 0
|
||||
@@ -1456,7 +1729,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
# Update latest video file
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
@@ -1478,7 +1751,79 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
|
||||
# Encode episode audio into a temporary audio file
|
||||
ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
|
||||
|
||||
if (
|
||||
episode_index == 0
|
||||
or self.meta.latest_episode is None
|
||||
or f"audio/{audio_key}/chunk_index" not in self.meta.latest_episode
|
||||
):
|
||||
# Initialize indices for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
old_chunk_idx = self.meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
|
||||
old_file_idx = self.meta.episodes[-1][f"audio/{audio_key}/file_index"]
|
||||
chunk_idx, file_idx = update_chunk_file_indices(
|
||||
old_chunk_idx, old_file_idx, self.meta.chunks_size
|
||||
)
|
||||
latest_duration_in_s = 0.0
|
||||
new_path = self.root / self.meta.audio_path.format(
|
||||
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Retrieve information from the latest updated audio file using latest_episode
|
||||
latest_ep = self.meta.latest_episode
|
||||
chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
|
||||
file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]
|
||||
|
||||
latest_path = self.root / self.meta.audio_path.format(
|
||||
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.audio_files_size_in_mb:
|
||||
# Move temporary episode audio to a new audio file in the dataset
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
new_path = self.root / self.meta.audio_path.format(
|
||||
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
# Update latest audio file
|
||||
concatenate_media_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
|
||||
# Remove temporary directory
|
||||
shutil.rmtree(str(ep_path.parent))
|
||||
|
||||
# Update audio info (only needed when first episode is encoded since it reads from episode 0)
|
||||
if episode_index == 0:
|
||||
self.meta.update_audio_info(audio_key)
|
||||
write_info(self.meta.info, self.meta.root) # ensure audio info always written properly
|
||||
|
||||
metadata = {
|
||||
"episode_index": episode_index,
|
||||
f"audio/{audio_key}/chunk_index": chunk_idx,
|
||||
f"audio/{audio_key}/file_index": file_idx,
|
||||
f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
|
||||
f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1487,11 +1832,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.camera_keys:
|
||||
for cam_key in self.meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
# Clean up audio files for the current episode buffer
|
||||
if delete_audio:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
if audio_file.is_file():
|
||||
audio_file.unlink()
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
@@ -1526,7 +1881,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
|
||||
def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert raw audio files into m4a audio files.
|
||||
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since audio encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{audio_key}_{episode_index:03d}.m4a"
|
||||
raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
encode_audio(raw_audio_file, temp_path, overwrite=True)
|
||||
raw_audio_file.unlink()
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1541,9 +1908,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -1560,6 +1931,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_writer = None
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
@@ -1581,6 +1953,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
obj.audio_backend = (
|
||||
audio_backend if audio_backend is not None else "torchcodec"
|
||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1601,6 +1976,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
@@ -1618,6 +1994,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
audio_backend=audio_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
@@ -18,12 +18,12 @@ from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None
|
||||
action: RobotAction | None = None, observation: RobotObservation | None = None
|
||||
) -> dict[PipelineFeatureType, dict[str, Any]]:
|
||||
"""
|
||||
Creates the initial features dict for the dataset from action and observation specs.
|
||||
|
||||
@@ -36,6 +36,7 @@ from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from soundfile import read
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -50,6 +51,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
@@ -57,13 +59,19 @@ STATS_PATH = "meta/stats.json"
|
||||
EPISODES_DIR = "meta/episodes"
|
||||
DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
AUDIO_DIR = "audio"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"
|
||||
|
||||
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0 # seconds
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
@@ -408,6 +416,16 @@ def load_image_as_numpy(
|
||||
return img_array
|
||||
|
||||
|
||||
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
||||
audio_data, _ = read(fpath, dtype="float32")
|
||||
|
||||
# Fill missing channel dimension when loading mono audio data
|
||||
if audio_data.ndim == 1:
|
||||
audio_data = np.expand_dims(audio_data, axis=1)
|
||||
|
||||
return audio_data
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
@@ -576,7 +594,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if ft["dtype"] == "video" or ft["dtype"] == "audio":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -639,7 +657,12 @@ def hw_to_dataset_features(
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
cam_fts = {
|
||||
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
mic_fts = {
|
||||
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
|
||||
}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
@@ -662,6 +685,14 @@ def hw_to_dataset_features(
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
for key, parameters in mic_fts.items():
|
||||
features[f"{prefix}.audio.{key}"] = {
|
||||
"dtype": "audio",
|
||||
"shape": (len(parameters[1]),),
|
||||
"names": ["channels"],
|
||||
"info": {"sample_rate": parameters[0]},
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
@@ -691,6 +722,8 @@ def build_dataset_frame(
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
elif ft["dtype"] == "audio":
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]
|
||||
|
||||
return frame
|
||||
|
||||
@@ -724,6 +757,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif ft["dtype"] == "audio":
|
||||
type = FeatureType.AUDIO
|
||||
if len(shape) != 2:
|
||||
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
@@ -802,6 +839,7 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
@@ -811,6 +849,10 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): The maximum number of files per chunk directory.
|
||||
data_files_size_in_mb (int | None): The maximum size for data files in MB.
|
||||
video_files_size_in_mb (int | None): The maximum size for video files in MB.
|
||||
audio_files_size_in_mb (int | None): The maximum size for audio files in MB.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
@@ -824,10 +866,12 @@ def create_empty_dataset_info(
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"audio_path": DEFAULT_AUDIO_PATH,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
@@ -1051,6 +1095,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "audio":
|
||||
return validate_feature_audio(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
@@ -1117,6 +1163,23 @@ def validate_feature_image_or_video(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c = expected_shape
|
||||
if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[
|
||||
-1
|
||||
]: # The number of frames might be different
|
||||
error_message += (
|
||||
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
|
||||
)
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str) -> str:
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
@@ -1172,12 +1235,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -59,6 +59,8 @@ from requests import HTTPError
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -79,7 +81,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -311,12 +313,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||
|
||||
# Check if adding this episode would exceed the limit
|
||||
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
@@ -352,7 +354,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
@@ -367,8 +369,124 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def get_audio_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
audio_keys = [key for key, ft in features.items() if ft["dtype"] == "audio"]
|
||||
return audio_keys
|
||||
|
||||
|
||||
def convert_audios(root: Path, new_root: Path, audio_file_size_in_mb: int):
|
||||
logging.info(f"Converting audios from {root} to {new_root}")
|
||||
|
||||
audio_keys = get_audio_keys(root)
|
||||
if len(audio_keys) == 0:
|
||||
return None
|
||||
|
||||
audio_keys = sorted(audio_keys)
|
||||
|
||||
eps_metadata_per_mic = []
|
||||
for microphone in audio_keys:
|
||||
eps_metadata = convert_audios_of_microphone(root, new_root, microphone, audio_file_size_in_mb)
|
||||
eps_metadata_per_mic.append(eps_metadata)
|
||||
|
||||
num_eps_per_mic = [len(eps_mic_map) for eps_mic_map in eps_metadata_per_mic]
|
||||
if len(set(num_eps_per_mic)) != 1:
|
||||
raise ValueError(f"All microphones dont have same number of episodes ({num_eps_per_mic}).")
|
||||
|
||||
episodes_metadata = []
|
||||
num_microphones = len(audio_keys)
|
||||
num_episodes = num_eps_per_mic[0]
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert audios"):
|
||||
# Sanity check
|
||||
ep_ids = [
|
||||
eps_metadata_per_mic[mic_idx][ep_idx]["episode_index"] for mic_idx in range(num_microphones)
|
||||
]
|
||||
ep_ids += [ep_idx]
|
||||
if len(set(ep_ids)) != 1:
|
||||
raise ValueError(f"All episode indices need to match ({ep_ids}).")
|
||||
|
||||
ep_dict = {}
|
||||
for mic_idx in range(num_microphones):
|
||||
ep_dict.update(eps_metadata_per_mic[mic_idx][ep_idx])
|
||||
episodes_metadata.append(ep_dict)
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def convert_audios_of_microphone(root: Path, new_root: Path, audio_key: str, audio_file_size_in_mb: int):
|
||||
# Access old paths to m4a
|
||||
audios_dir = root / "audio"
|
||||
ep_paths = sorted(audios_dir.glob(f"*/{audio_key}/*.m4a"))
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert audios of {audio_key}"):
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
|
||||
|
||||
# Check if adding this episode would exceed the limit
|
||||
if size_in_mb + ep_size_in_mb >= audio_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
)
|
||||
|
||||
# Update episodes metadata for the file we just saved
|
||||
for i, _ in enumerate(paths_to_cat):
|
||||
past_ep_idx = ep_idx - len(paths_to_cat) + i
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/chunk_index"] = chunk_idx
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/file_index"] = file_idx
|
||||
|
||||
# Move to next file and start fresh with current episode
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
size_in_mb = 0
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
|
||||
# Add current episode metadata
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
f"audio/{audio_key}/chunk_index": chunk_idx, # Will be updated when file is saved
|
||||
f"audio/{audio_key}/file_index": file_idx, # Will be updated when file is saved
|
||||
f"audio/{audio_key}/from_timestamp": duration_in_s,
|
||||
f"audio/{audio_key}/to_timestamp": duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
episodes_metadata.append(ep_metadata)
|
||||
|
||||
# Add current episode to accumulation
|
||||
paths_to_cat.append(ep_path)
|
||||
size_in_mb += ep_size_in_mb
|
||||
duration_in_s += ep_duration_in_s
|
||||
ep_idx += 1
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
)
|
||||
|
||||
# Update episodes metadata for the final file
|
||||
for i, _ in enumerate(paths_to_cat):
|
||||
past_ep_idx = ep_idx - len(paths_to_cat) + i
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/chunk_index"] = chunk_idx
|
||||
episodes_metadata[past_ep_idx][f"audio/{audio_key}/file_index"] = file_idx
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None, episodes_audios=None
|
||||
):
|
||||
num_episodes = len(episodes_metadata)
|
||||
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
|
||||
@@ -392,16 +510,30 @@ def generate_episode_metadata_dict(
|
||||
ep_video = episodes_videos[i]
|
||||
ep_ids_set.add(ep_video["episode_index"])
|
||||
|
||||
if episodes_audios is None:
|
||||
ep_audio = {}
|
||||
else:
|
||||
ep_audio = episodes_audios[i]
|
||||
ep_ids_set.add(ep_audio["episode_index"])
|
||||
|
||||
if len(ep_ids_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||
|
||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
||||
ep_dict = {
|
||||
**ep_metadata,
|
||||
**ep_video,
|
||||
**ep_audio,
|
||||
**ep_legacy_metadata,
|
||||
**flatten_dict({"stats": ep_stats}),
|
||||
}
|
||||
ep_dict["meta/episodes/chunk_index"] = 0
|
||||
ep_dict["meta/episodes/file_index"] = 0
|
||||
yield ep_dict
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
def convert_episodes_metadata(
|
||||
root, new_root, episodes_metadata, episodes_video_metadata=None, episodes_audio_metadata=None
|
||||
):
|
||||
logging.info(f"Converting episodes metadata from {root} to {new_root}")
|
||||
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
@@ -410,13 +542,19 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
|
||||
if episodes_video_metadata is not None:
|
||||
num_eps_set.add(len(episodes_video_metadata))
|
||||
if episodes_audio_metadata is not None:
|
||||
num_eps_set.add(len(episodes_audio_metadata))
|
||||
|
||||
if len(num_eps_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
|
||||
|
||||
ds_episodes = Dataset.from_generator(
|
||||
lambda: generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
||||
episodes_legacy_metadata,
|
||||
episodes_metadata,
|
||||
episodes_stats,
|
||||
episodes_video_metadata,
|
||||
episodes_audio_metadata,
|
||||
)
|
||||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
@@ -425,20 +563,22 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
write_stats(stats, new_root)
|
||||
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb, audio_file_size_in_mb):
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = V30
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
info["data_files_size_in_mb"] = data_file_size_in_mb
|
||||
info["video_files_size_in_mb"] = video_file_size_in_mb
|
||||
info["audio_files_size_in_mb"] = audio_file_size_in_mb
|
||||
info["data_path"] = DEFAULT_DATA_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
|
||||
info["audio_path"] = DEFAULT_AUDIO_PATH if info["audio_path"] is not None else None
|
||||
info["fps"] = int(info["fps"])
|
||||
logging.info(f"Converting info from {root} to {new_root}")
|
||||
for key in info["features"]:
|
||||
if info["features"][key]["dtype"] == "video":
|
||||
# already has fps in video_info
|
||||
if info["features"][key]["dtype"] == "video" or info["features"][key]["dtype"] == "audio":
|
||||
# already has fps in video_info or audio_info
|
||||
continue
|
||||
info["features"][key]["fps"] = info["fps"]
|
||||
write_info(info, new_root)
|
||||
@@ -449,6 +589,7 @@ def convert_dataset(
|
||||
branch: str | None = None,
|
||||
data_file_size_in_mb: int | None = None,
|
||||
video_file_size_in_mb: int | None = None,
|
||||
audio_file_size_in_mb: int | None = None,
|
||||
root: str | Path | None = None,
|
||||
push_to_hub: bool = True,
|
||||
force_conversion: bool = False,
|
||||
@@ -457,6 +598,8 @@ def convert_dataset(
|
||||
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_file_size_in_mb is None:
|
||||
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_file_size_in_mb is None:
|
||||
audio_file_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
|
||||
# First check if the dataset already has a v3.0 version
|
||||
if root is None and not force_conversion:
|
||||
@@ -498,7 +641,10 @@ def convert_dataset(
|
||||
convert_tasks(root, new_root)
|
||||
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
|
||||
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
|
||||
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
|
||||
episodes_audios_metadata = convert_audios(root, new_root, audio_file_size_in_mb)
|
||||
convert_episodes_metadata(
|
||||
root, new_root, episodes_metadata, episodes_videos_metadata, episodes_audios_metadata
|
||||
)
|
||||
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
@@ -511,7 +657,7 @@ def convert_dataset(
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*", "audio/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
@@ -549,6 +695,12 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for data and 500 for videos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-file-size-in-mb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for audio.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
|
||||
@@ -397,42 +397,42 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
def concatenate_media_files(
|
||||
input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
Concatenate multiple media files (video & audio) into a single media file using pyav.
|
||||
|
||||
This function takes a list of video input file paths and concatenates them into a single
|
||||
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
This function takes a list of input media file paths and concatenates them into a single
|
||||
output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
concatenation without re-encoding.
|
||||
|
||||
Args:
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
input_media_paths: Ordered list of input media file paths to concatenate.
|
||||
output_media_path: Path to the output media file.
|
||||
overwrite: Whether to overwrite the output media file if it already exists. Default is True.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
|
||||
- Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input media files to have the same
|
||||
codec, resolution, and frame rate for proper concatenation.
|
||||
"""
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
output_media_path = Path(output_media_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
if output_media_path.exists() and not overwrite:
|
||||
logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_media_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
if len(input_media_paths) == 0:
|
||||
raise FileNotFoundError("No input media paths provided.")
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
# Create a temporary .ffconcat file to list the input media paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
for input_path in input_video_paths:
|
||||
for input_path in input_media_paths:
|
||||
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
||||
tmp_concatenate_file.flush()
|
||||
tmp_concatenate_path = tmp_concatenate_file.name
|
||||
@@ -442,11 +442,12 @@ def concatenate_video_files(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
# Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
|
||||
with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
|
||||
tmp_output_media_path = tmp_named_file.name
|
||||
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
tmp_output_media_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
|
||||
# Replicate input streams in output container
|
||||
@@ -461,6 +462,7 @@ def concatenate_video_files(
|
||||
stream_map[input_stream.index].time_base = input_stream.time_base
|
||||
|
||||
# Demux + remux packets (no re-encode)
|
||||
last_dts = None
|
||||
for packet in input_container.demux():
|
||||
# Skip packets from un-mapped streams
|
||||
if packet.stream.index not in stream_map:
|
||||
@@ -469,6 +471,16 @@ def concatenate_video_files(
|
||||
# Skip demux flushing packets
|
||||
if packet.dts is None:
|
||||
continue
|
||||
else:
|
||||
# Enforce strictly increasing decoding timestamps (DTS)
|
||||
if last_dts is not None and packet.dts <= last_dts:
|
||||
shift = last_dts - packet.dts + 1
|
||||
packet.dts += shift
|
||||
packet.pts += shift # Presenting timestamps (PTS) are the same as DTS here
|
||||
logging.warning(
|
||||
f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
|
||||
)
|
||||
last_dts = packet.dts
|
||||
|
||||
output_stream = stream_map[packet.stream.index]
|
||||
packet.stream = output_stream
|
||||
@@ -476,7 +488,7 @@ def concatenate_video_files(
|
||||
|
||||
input_container.close()
|
||||
output_container.close()
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
shutil.move(tmp_output_media_path, output_media_path)
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
|
||||
|
||||
@@ -512,38 +524,6 @@ with warnings.catch_warnings():
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
with av.open(str(video_path), "r") as audio_file:
|
||||
try:
|
||||
audio_stream = audio_file.streams.audio[0]
|
||||
except IndexError:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
return {"has_audio": False}
|
||||
|
||||
audio_info["audio.channels"] = audio_stream.channels
|
||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
||||
# In an ideal loseless case : fixed number of bits per sample.
|
||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
||||
audio_info["has_audio"] = True
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
@@ -573,9 +553,6 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
@@ -590,22 +567,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
Get the duration of a media file (video & audio) in seconds using PyAV.
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
media_path: Path to the media file.
|
||||
|
||||
Returns:
|
||||
Duration of the video in seconds.
|
||||
Duration of the media file in seconds.
|
||||
"""
|
||||
with av.open(str(video_path)) as container:
|
||||
# Get the first video stream
|
||||
video_stream = container.streams.video[0]
|
||||
with av.open(str(media_path)) as container:
|
||||
# Get the first stream
|
||||
stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
|
||||
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
|
||||
if video_stream.duration is not None:
|
||||
duration = float(video_stream.duration * video_stream.time_base)
|
||||
if stream.duration is not None:
|
||||
duration = float(stream.duration * stream.time_base)
|
||||
else:
|
||||
# Fallback to container duration if stream duration is not available
|
||||
duration = float(container.duration / av.time_base)
|
||||
@@ -614,12 +591,12 @@ def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
|
||||
class VideoEncodingManager:
|
||||
"""
|
||||
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
|
||||
Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
|
||||
|
||||
This manager handles:
|
||||
- Batch encoding for any remaining episodes when recording interrupted
|
||||
- Cleaning up temporary image files from interrupted episodes
|
||||
- Removing empty image directories
|
||||
- Cleaning up temporary image and audio files from interrupted episodes
|
||||
- Removing empty image and audio directories
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset instance
|
||||
@@ -646,6 +623,7 @@ class VideoEncodingManager:
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
self.dataset._batch_save_episode_audio(start_ep, end_ep)
|
||||
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
@@ -662,6 +640,15 @@ class VideoEncodingManager:
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
for key in self.dataset.meta.audio_keys:
|
||||
audio_file = self.dataset._get_raw_audio_file_path(
|
||||
episode_index=interrupted_episode_index, audio_key=key
|
||||
)
|
||||
if audio_file.exists():
|
||||
logging.debug(
|
||||
f"Cleaning up interrupted episode audio for episode {interrupted_episode_index}, microphone {key}"
|
||||
)
|
||||
audio_file.unlink()
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
@@ -675,4 +662,16 @@ class VideoEncodingManager:
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
# Clean up any remaining audio directory if it's empty
|
||||
audio_dir = self.dataset.root / "raw_audio"
|
||||
# Check for any remaining WAV files
|
||||
wav_files = list(audio_dir.rglob("*.wav"))
|
||||
if len(wav_files) == 0:
|
||||
# Only remove the raw_audio directory if no WAV files remain
|
||||
if audio_dir.exists():
|
||||
shutil.rmtree(audio_dir)
|
||||
logging.debug("Cleaned up empty audio directory")
|
||||
else:
|
||||
logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401
|
||||
from .configs import AlohaEnv, EnvConfig, HubEnvConfig, PushtEnv # noqa: F401
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
@@ -68,6 +68,22 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class HubEnvConfig(EnvConfig):
|
||||
"""Base class for environments that delegate creation to a hub-hosted make_env.
|
||||
|
||||
Hub environments download and execute remote code from the HF Hub.
|
||||
The hub_path points to a repository containing an env.py with a make_env function.
|
||||
"""
|
||||
|
||||
hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py"
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
# Not used for hub environments - the hub's make_env handles everything
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
@@ -368,3 +384,71 @@ class MetaworldEnv(EnvConfig):
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("isaaclab_arena")
|
||||
@dataclass
|
||||
class IsaaclabArenaEnv(HubEnvConfig):
|
||||
hub_path: str = "nvidia/isaaclab-arena-envs"
|
||||
episode_length: int = 300
|
||||
num_envs: int = 1
|
||||
embodiment: str | None = "gr1_pink"
|
||||
object: str | None = "power_drill"
|
||||
mimic: bool = False
|
||||
teleop_device: str | None = None
|
||||
seed: int | None = 42
|
||||
device: str | None = "cuda:0"
|
||||
disable_fabric: bool = False
|
||||
enable_cameras: bool = False
|
||||
headless: bool = False
|
||||
enable_pinocchio: bool = True
|
||||
environment: str | None = "gr1_microwave"
|
||||
task: str | None = "Reach out to the microwave and open it."
|
||||
state_dim: int = 54
|
||||
action_dim: int = 36
|
||||
camera_height: int = 512
|
||||
camera_width: int = 512
|
||||
video: bool = False
|
||||
video_length: int = 100
|
||||
video_interval: int = 200
|
||||
# Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
|
||||
state_keys: str = "robot_joint_pos"
|
||||
# Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
|
||||
# Set to None or "" for environments without cameras
|
||||
camera_keys: str | None = None
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
kwargs: dict | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.kwargs:
|
||||
# dynamically convert kwargs to fields in the dataclass
|
||||
# NOTE! the new fields will not bee seen by the dataclass repr
|
||||
field_names = {f.name for f in fields(self)}
|
||||
for key, value in self.kwargs.items():
|
||||
if key not in field_names and key != "kwargs":
|
||||
setattr(self, key, value)
|
||||
self.kwargs = None
|
||||
|
||||
# Set action feature
|
||||
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
|
||||
self.features_map[ACTION] = ACTION
|
||||
|
||||
# Set state feature
|
||||
self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
|
||||
self.features_map[OBS_STATE] = OBS_STATE
|
||||
|
||||
# Add camera features for each camera key
|
||||
if self.enable_cameras and self.camera_keys:
|
||||
for cam_key in self.camera_keys.split(","):
|
||||
cam_key = cam_key.strip()
|
||||
if cam_key:
|
||||
self.features[cam_key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(self.camera_height, self.camera_width, 3),
|
||||
)
|
||||
self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
@@ -20,11 +20,11 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
@@ -73,6 +73,26 @@ def make_env_pre_post_processors(
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||
# Parse comma-separated keys (handle None for state-based policies)
|
||||
if env_cfg.state_keys:
|
||||
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
|
||||
else:
|
||||
state_keys = ()
|
||||
if env_cfg.camera_keys:
|
||||
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
|
||||
else:
|
||||
camera_keys = ()
|
||||
if not state_keys and not camera_keys:
|
||||
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
||||
preprocessor_steps.append(
|
||||
IsaaclabArenaProcessorStep(
|
||||
state_keys=state_keys,
|
||||
camera_keys=camera_keys,
|
||||
)
|
||||
)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
||||
|
||||
@@ -98,7 +118,6 @@ def make_env(
|
||||
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
|
||||
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
|
||||
Default False — must be set to True to import/exec hub `env.py`.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
@@ -112,19 +131,35 @@ def make_env(
|
||||
"""
|
||||
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
|
||||
# simplified: only support hub-provided `make_env`
|
||||
# TODO: (jadechoghari): deprecate string API and remove this check
|
||||
if isinstance(cfg, str):
|
||||
hub_path: str | None = cfg
|
||||
elif isinstance(cfg, HubEnvConfig):
|
||||
hub_path = cfg.hub_path
|
||||
else:
|
||||
hub_path = None
|
||||
|
||||
# If hub_path is set, download and call hub-provided `make_env`
|
||||
if hub_path:
|
||||
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False
|
||||
repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
|
||||
repo_id, file_path, local_file, revision = _download_hub_file(
|
||||
hub_path, trust_remote_code, hub_cache_dir
|
||||
)
|
||||
|
||||
# import and surface clear import errors
|
||||
module = _import_hub_module(local_file, repo_id)
|
||||
|
||||
# call the hub-provided make_env
|
||||
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
env_cfg = None if isinstance(cfg, str) else cfg
|
||||
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg)
|
||||
|
||||
# normalize the return into {suite: {task_id: vec_env}}
|
||||
return _normalize_hub_result(raw_result)
|
||||
|
||||
# At this point, cfg must be an EnvConfig (not a string) since hub_path would have been set otherwise
|
||||
if isinstance(cfg, str):
|
||||
raise TypeError("cfg should be an EnvConfig at this point")
|
||||
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs` must be at least 1")
|
||||
|
||||
|
||||
@@ -29,6 +29,8 @@ from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
@@ -237,7 +239,7 @@ class LiberoEnv(gym.Env):
|
||||
env.reset()
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
|
||||
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
|
||||
images = {}
|
||||
for camera_name in self.camera_name:
|
||||
image = raw_obs[camera_name]
|
||||
@@ -291,9 +293,9 @@ class LiberoEnv(gym.Env):
|
||||
def reset(self, seed=None, **kwargs):
|
||||
super().reset(seed=seed)
|
||||
self._env.seed(seed)
|
||||
if self.init_states and self._init_states is not None:
|
||||
self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
@@ -313,7 +315,7 @@ class LiberoEnv(gym.Env):
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
if action.ndim != 1:
|
||||
raise ValueError(
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
|
||||
@@ -25,6 +25,8 @@ import metaworld.policies as policies
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
# ---- Load configuration data from the external JSON file ----
|
||||
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
|
||||
try:
|
||||
@@ -161,7 +163,7 @@ class MetaworldEnv(gym.Env):
|
||||
env._freeze_rand_vec = False # otherwise no randomization
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
|
||||
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
|
||||
image = None
|
||||
if self._env is not None:
|
||||
image = self._env.render()
|
||||
@@ -196,7 +198,7 @@ class MetaworldEnv(gym.Env):
|
||||
self,
|
||||
seed: int | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
) -> tuple[RobotObservation, dict[str, Any]]:
|
||||
"""
|
||||
Reset the environment to its initial state.
|
||||
|
||||
@@ -204,7 +206,7 @@ class MetaworldEnv(gym.Env):
|
||||
seed (Optional[int]): Random seed for environment initialization.
|
||||
|
||||
Returns:
|
||||
observation (Dict[str, Any]): The initial formatted observation.
|
||||
observation (RobotObservation): The initial formatted observation.
|
||||
info (Dict[str, Any]): Additional info about the reset state.
|
||||
"""
|
||||
super().reset(seed=seed)
|
||||
@@ -216,7 +218,7 @@ class MetaworldEnv(gym.Env):
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
"""
|
||||
Perform one environment step.
|
||||
|
||||
@@ -224,7 +226,7 @@ class MetaworldEnv(gym.Env):
|
||||
action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
|
||||
|
||||
Returns:
|
||||
observation (Dict[str, Any]): The formatted observation after the step.
|
||||
observation (RobotObservation): The formatted observation after the step.
|
||||
reward (float): The scalar reward for this step.
|
||||
terminated (bool): Whether the episode terminated successfully.
|
||||
truncated (bool): Whether the episode was truncated due to a time limit.
|
||||
|
||||
@@ -29,6 +29,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.processor import RobotObservation
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
@@ -46,7 +47,7 @@ def _convert_nested_dict(d):
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
# TODO(jadechoghari, imstevenpmwork): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -98,11 +99,19 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
if "robot_state" in observations:
|
||||
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
|
||||
|
||||
# Handle IsaacLab Arena format: observations have 'policy' and 'camera_obs' keys
|
||||
if "policy" in observations:
|
||||
return_observations[f"{OBS_STR}.policy"] = observations["policy"]
|
||||
|
||||
if "camera_obs" in observations:
|
||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# TODO(jadechoghari, imstevenpmwork): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
@@ -144,7 +153,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
)
|
||||
|
||||
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
task_result = env.call("task_description")
|
||||
@@ -302,7 +311,7 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any:
|
||||
return module
|
||||
|
||||
|
||||
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
|
||||
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None) -> Any:
|
||||
"""
|
||||
Ensure module exposes make_env and call it.
|
||||
"""
|
||||
@@ -311,7 +320,11 @@ def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
|
||||
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
|
||||
)
|
||||
entry_fn = module.make_env
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
# Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID)
|
||||
if cfg is not None:
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg)
|
||||
else:
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
|
||||
|
||||
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,5 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_so100_leader import BiSO100Leader
|
||||
from .config_bi_so100_leader import BiSO100LeaderConfig
|
||||
from .configs import MicrophoneConfig
|
||||
from .microphone import Microphone
|
||||
from .utils import make_microphones_from_configs
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,5 +12,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_so100_leader import SO100LeaderConfig
|
||||
from .so100_leader import SO100Leader
|
||||
from .configuration_anyskin import AnyskinSensorConfig
|
||||
from .sensor_anyskin import AnyskinSensor
|
||||
45
src/lerobot/microphones/anyskin/configuration_anyskin.py
Normal file
45
src/lerobot/microphones/anyskin/configuration_anyskin.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("anyskin")
|
||||
@dataclass
|
||||
class AnyskinSensorConfig(MicrophoneConfig):
|
||||
"""Configuration class for Anyskin tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
|
||||
|
||||
This class provides configuration options for Anyskin tactile sensors, including serial port, sample rate and channels.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
AnyskinSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
|
||||
AnyskinSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
|
||||
```
|
||||
|
||||
Attributes:
|
||||
sensor_port: Serial port of the tactile sensor.
|
||||
baud_rate: Baud rate of the tactile sensor.
|
||||
sample_rate: Sample rate in Hz for the tactile sensor.
|
||||
channels: List of channel numbers to use for the tactile sensor.
|
||||
"""
|
||||
|
||||
sensor_port: str
|
||||
baud_rate: int = 115_200
|
||||
sensor_id: int = 0
|
||||
burst_mode: bool = True
|
||||
temp_filtered: bool = False
|
||||
473
src/lerobot/microphones/anyskin/sensor_anyskin.py
Normal file
473
src/lerobot/microphones/anyskin/sensor_anyskin.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the AnyskinSensor class for capturing tactile data from Anyskin tactile sensors.
|
||||
"""
|
||||
|
||||
from doctest import master
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.hub import T
|
||||
import numpy as np
|
||||
from serial import Serial, serialutil
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_anyskin import AnyskinSensorConfig
|
||||
|
||||
from anyskin import AnySkinBase, AnySkinDummy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MAGNETS_CHANNELS = 5
|
||||
|
||||
class AnyskinSensor(Microphone):
|
||||
"""
|
||||
The AnyskinSensor class handles all Anyskin tactile sensors.
|
||||
|
||||
A AnyskinSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
from lerobot.common.robot_devices.microphones.configs import AnyskinSensorConfig
|
||||
|
||||
config = AnyskinSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
|
||||
microphone = AnyskinSensor(config)
|
||||
|
||||
microphone.connect()
|
||||
microphone.start_recording("some/output/file.wav")
|
||||
...
|
||||
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
|
||||
...
|
||||
microphone.stop_recording()
|
||||
microphone.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: AnyskinSensorConfig):
|
||||
""" "
|
||||
Initializes the AnyskinSensor instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the sensor.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Sensor port
|
||||
self.sensor_port = config.sensor_port
|
||||
|
||||
# Baud rate
|
||||
self.baud_rate = config.baud_rate
|
||||
|
||||
# Input audio recording process and events
|
||||
self.record_process = None
|
||||
self.record_stop_event = process_Event()
|
||||
self.record_start_event = process_Event()
|
||||
self.record_close_event = process_Event()
|
||||
self.record_is_started_event = process_Event()
|
||||
self.audio_callback_start_event = process_Event()
|
||||
|
||||
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# SharedArray to store audio from the recording process.
|
||||
self.read_shared_array = None
|
||||
self.local_read_shared_array = None
|
||||
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||
self.write_thread = None
|
||||
self.write_stop_event = None
|
||||
self.write_is_started_event = None
|
||||
|
||||
self.logs = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.sensor_port})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the sensor is currently connected.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is connected and ready to start recording,
|
||||
False otherwise.
|
||||
"""
|
||||
return self.record_process is not None and self.record_process.is_alive()
|
||||
|
||||
@property
|
||||
def is_recording(self) -> bool:
|
||||
"""Check if the sensor is currently recording.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is recording, False otherwise.
|
||||
"""
|
||||
return self.record_is_started_event.is_set()
|
||||
|
||||
@property
|
||||
def is_writing(self) -> bool:
|
||||
"""Check if the sensor is currently writing to a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is writing to a file, False otherwise.
|
||||
"""
|
||||
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||
|
||||
@staticmethod
|
||||
def find_microphones() -> list[dict[str, Any]]:
|
||||
"""Detects available sensors connected to the system.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected sensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the sensor.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
|
||||
|
||||
# Create or reset queue and shared array
|
||||
self.read_shared_array = SharedArray(
|
||||
shape=(self.sample_rate * 10, len(self.channels)),
|
||||
dtype=np.dtype("int16"),
|
||||
)
|
||||
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||
self.write_queue = process_Queue()
|
||||
|
||||
# Reset events
|
||||
self.record_start_event.clear()
|
||||
self.record_stop_event.clear()
|
||||
self.record_close_event.clear()
|
||||
self.record_is_started_event.clear()
|
||||
self.audio_callback_start_event.clear()
|
||||
|
||||
# Create and start an audio input stream with a recording callback
|
||||
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the busy_wait function.
|
||||
process_init_event = process_Event()
|
||||
self.record_process = Process(
|
||||
target=self._record_process,
|
||||
args=(
|
||||
self.sensor_port,
|
||||
self.baud_rate,
|
||||
self.channels,
|
||||
process_init_event,
|
||||
self.record_start_event,
|
||||
self.record_stop_event,
|
||||
self.record_close_event,
|
||||
self.record_is_started_event,
|
||||
self.audio_callback_start_event,
|
||||
self.write_queue,
|
||||
self.read_shared_array,
|
||||
),
|
||||
)
|
||||
self.record_process.daemon = True
|
||||
self.record_process.start()
|
||||
|
||||
is_init = process_init_event.wait(
|
||||
timeout=5.0
|
||||
) # Wait for the recording process to be started, and to potentially raise an error on failure.
|
||||
if not self.is_connected or not is_init:
|
||||
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def _record_process(
|
||||
sensor_port,
|
||||
baud_rate,
|
||||
channels,
|
||||
process_init_event,
|
||||
record_start_event,
|
||||
record_stop_event,
|
||||
record_close_event,
|
||||
record_is_started_event,
|
||||
audio_callback_start_event,
|
||||
write_queue,
|
||||
read_shared_array,
|
||||
) -> None:
|
||||
channels_index = np.array(channels) - 1
|
||||
local_read_shared_array = read_shared_array.get_local_array()
|
||||
|
||||
def tactile_callback(tactile_sensor: AnySkinBase):
|
||||
"""
|
||||
Parse the tactile data from the raw input data.
|
||||
"""
|
||||
if audio_callback_start_event.is_set():
|
||||
timestamp, indata = tactile_sensor.get_sample()
|
||||
indata = indata.reshape(-1, MAX_MAGNETS_CHANNELS)
|
||||
write_queue.put_nowait(indata[:, channels_index])
|
||||
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||
|
||||
try:
|
||||
tactile_sensor = AnySkinBase(
|
||||
num_mags=MAX_MAGNETS_CHANNELS,
|
||||
port=sensor_port,
|
||||
baudrate=baud_rate,
|
||||
burst_mode=True,
|
||||
device_id=0, #TODO(CarolinePascal): create an abstract increasing id for each sensor
|
||||
temp_filtered=False,
|
||||
) #TODO(CarolinePascal): add timeout on serial connection ?
|
||||
except (serialutil.SerialException, AttributeError) as e:
|
||||
raise RuntimeError(f"Error connecting sensor connected to {sensor_port}: {e}")
|
||||
|
||||
process_init_event.set()
|
||||
|
||||
while True:
|
||||
start_flag = record_start_event.wait(timeout=0.1)
|
||||
if record_close_event.is_set():
|
||||
break
|
||||
elif not start_flag:
|
||||
continue
|
||||
record_is_started_event.set()
|
||||
while not record_stop_event.is_set():
|
||||
tactile_callback(tactile_sensor) # Initial flush is already done in the constructor.
|
||||
record_is_started_event.clear()
|
||||
tactile_sensor.close() # Closes the inherited serial connection.
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect the sensor and release any resources.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
|
||||
if self.is_recording:
|
||||
self.stop_recording()
|
||||
|
||||
self.record_close_event.set()
|
||||
self.read_shared_array.delete()
|
||||
self.write_queue.close()
|
||||
self.record_process.join()
|
||||
|
||||
if self.is_connected:
|
||||
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
def start_recording(
|
||||
self,
|
||||
output_file: str | Path | None = None,
|
||||
multiprocessing: bool | None = False,
|
||||
overwrite: bool | None = True,
|
||||
barrier: Barrier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Start recording tactile data from the sensor.
|
||||
|
||||
Args:
|
||||
output_file: Optional path to save the recorded tactile data.
|
||||
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
barrier: If not None, ensures that multiple sensors start recording at the same time.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if self.is_recording:
|
||||
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
|
||||
|
||||
# Reset queue and shared memory
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue)
|
||||
|
||||
# Reset stop event
|
||||
self.record_stop_event.clear()
|
||||
|
||||
# Write recordings into a file if output_file is provided
|
||||
if output_file is not None:
|
||||
output_file = Path(output_file)
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if output_file.exists():
|
||||
if overwrite:
|
||||
output_file.unlink()
|
||||
else:
|
||||
raise FileExistsError(
|
||||
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
|
||||
)
|
||||
|
||||
if multiprocessing:
|
||||
self.write_stop_event = process_Event()
|
||||
self.write_is_started_event = process_Event()
|
||||
self.write_thread = Process(
|
||||
target=AnyskinSensor._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.write_stop_event = thread_Event()
|
||||
self.write_is_started_event = thread_Event()
|
||||
self.write_thread = Thread(
|
||||
target=AnyskinSensor._write_loop,
|
||||
args=(
|
||||
self.write_queue,
|
||||
self.write_stop_event,
|
||||
self.write_is_started_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
self.write_thread.daemon = True
|
||||
self.write_thread.start()
|
||||
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
|
||||
|
||||
self.record_start_event.set() # Start the input audio stream process
|
||||
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
|
||||
|
||||
if barrier is not None:
|
||||
barrier.wait() # Wait for multiple input audio streams to be started at the same time
|
||||
|
||||
self.audio_callback_start_event.set()
|
||||
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
|
||||
if output_file is not None and not self.is_writing:
|
||||
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def _read(self) -> np.ndarray:
|
||||
"""
|
||||
Thread/Process-safe callback to read available audio data
|
||||
"""
|
||||
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Capture and return a single audio chunk from the sensor.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured audio chunk as a numpy array.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
tactile_readings = self._read()
|
||||
|
||||
# log the number of seconds it took to read the audio chunk
|
||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the audio chunk was received
|
||||
self.logs["timestamp_utc"] = time.perf_counter()
|
||||
|
||||
return tactile_readings
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""Internal loop run by the background thread for asynchronous reading."""
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
"""Stop recording audio from the sensor."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
self.audio_callback_start_event.clear()
|
||||
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||
self.record_stop_event.set()
|
||||
|
||||
self.read_shared_array.reset()
|
||||
self._clear_queue(self.write_queue, join_queue=True)
|
||||
|
||||
if self.is_writing:
|
||||
self.write_stop_event.set()
|
||||
self.write_thread.join()
|
||||
|
||||
timeout = 1.0
|
||||
while self.is_recording and timeout > 0:
|
||||
time.sleep(0.01)
|
||||
timeout -= 0.01
|
||||
|
||||
if self.is_recording:
|
||||
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
|
||||
if self.is_writing:
|
||||
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
|
||||
@staticmethod
|
||||
def _clear_queue(queue, join_queue: bool = False):
|
||||
"""
|
||||
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
if join_queue:
|
||||
queue.join()
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _write_loop(
|
||||
queue,
|
||||
write_stop_event: Event,
|
||||
write_is_started_event: Event,
|
||||
sample_rate: int,
|
||||
channels: list[int],
|
||||
output_file: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Thread/Process-safe loop to write audio data into a file.
|
||||
"""
|
||||
# Can only be run on a single process/thread for file writing safety
|
||||
with SoundFile(
|
||||
output_file,
|
||||
mode="w",
|
||||
samplerate=sample_rate,
|
||||
channels=len(channels),
|
||||
format="WAV",
|
||||
subtype="FLOAT", # Subtype for float32 values
|
||||
) as file:
|
||||
write_is_started_event.set()
|
||||
while not write_stop_event.is_set():
|
||||
try:
|
||||
file.write(
|
||||
queue.get(timeout=0.005)
|
||||
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
continue
|
||||
write_is_started_event.clear()
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,13 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
import draccus
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_so100_leader")
|
||||
@dataclass
|
||||
class BiSO100LeaderConfig(TeleoperatorConfig):
|
||||
left_arm_port: str
|
||||
right_arm_port: str
|
||||
@dataclass(kw_only=True)
|
||||
class MicrophoneConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
sample_rate: int | None = None
|
||||
channels: list[int] | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
140
src/lerobot/microphones/microphone.py
Normal file
140
src/lerobot/microphones/microphone.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from threading import Barrier
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
|
||||
|
||||
class Microphone(abc.ABC):
|
||||
"""Base class for microphone implementations.
|
||||
|
||||
Defines a standard interface for microphone operations across different backends.
|
||||
Subclasses must implement all abstract methods.
|
||||
|
||||
Manages basic microphone properties (sample rate, channels) and core operations:
|
||||
- Connection/disconnection
|
||||
- Start/stop recording
|
||||
- Audio chunk reading
|
||||
|
||||
Attributes:
|
||||
sample_rate (int | None): Configured sample rate in Hz
|
||||
channels (list[int] | None): List of channel numbers to record
|
||||
|
||||
Example:
|
||||
class MyMicrophone(Microphone):
|
||||
def __init__(self, config): ...
|
||||
@property
|
||||
def is_connected(self) -> bool: ...
|
||||
def connect(self): ...
|
||||
# Plus other required methods
|
||||
"""
|
||||
|
||||
def __init__(self, config: MicrophoneConfig):
|
||||
"""Initialize the microphone with the given configuration.
|
||||
|
||||
Args:
|
||||
config: Microphone configuration containing sample rate and channels.
|
||||
"""
|
||||
self.sample_rate: int | None = config.sample_rate
|
||||
self.channels: list[int] | None = config.channels
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the microphone is currently connected.
|
||||
|
||||
Returns:
|
||||
bool: True if the microphone is connected and ready to start recording,
|
||||
False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_recording(self) -> bool:
|
||||
"""Check if the microphone is currently recording.
|
||||
|
||||
Returns:
|
||||
bool: True if the microphone is recording, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_writing(self) -> bool:
|
||||
"""Check if the microphone is currently writing to a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the microphone is writing to a file, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def find_microphones() -> list[dict[str, Any]]:
|
||||
"""Detects available microphones connected to the system.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected microphone.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self) -> None:
|
||||
"""Establish connection to the microphone."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def start_recording(
|
||||
self,
|
||||
output_file: str | Path | None = None,
|
||||
multiprocessing: bool | None = False,
|
||||
overwrite: bool | None = True,
|
||||
barrier: Barrier | None = None,
|
||||
) -> None:
|
||||
"""Start recording audio from the microphone.
|
||||
|
||||
Args:
|
||||
output_file: Optional path to save the recorded audio.
|
||||
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
barrier: If not None, ensures that multiple microphones start recording at the same time.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self) -> np.ndarray:
|
||||
"""Capture and return a single audio chunk from the microphone.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured audio chunk as a numpy array.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def stop_recording(self) -> None:
|
||||
"""Stop recording audio from the microphone."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect the microphone and release any resources."""
|
||||
pass
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,5 +12,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_so101_follower import SO101FollowerConfig
|
||||
from .so101_follower import SO101Follower
|
||||
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||
from .microphone_portaudio import PortAudioMicrophone
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user