mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
48 Commits
2686450d68
...
feat/vla-j
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac9128b319 | ||
|
|
61d82a5773 | ||
|
|
d8ed30d58c | ||
|
|
79f6756505 | ||
|
|
7a98c56ce4 | ||
|
|
9961eb6918 | ||
|
|
e1852db71a | ||
|
|
c2b29c8ae0 | ||
|
|
58eac863aa | ||
|
|
952e5146dc | ||
|
|
37fda2a6fc | ||
|
|
df7d5132d1 | ||
|
|
8efa5cabe9 | ||
|
|
7e23859c55 | ||
|
|
a24f669deb | ||
|
|
b7727b8a6c | ||
|
|
7db4414e6b | ||
|
|
7da594fda8 | ||
|
|
01ce5d7af1 | ||
|
|
83ef59e020 | ||
|
|
997b713f14 | ||
|
|
1e3d25f10e | ||
|
|
8c0efb8295 | ||
|
|
9ef0fd5433 | ||
|
|
26d2ac48a8 | ||
|
|
0f29cd3167 | ||
|
|
64c9570547 | ||
|
|
8b03d25fef | ||
|
|
5e37d97631 | ||
|
|
a71e0d34ad | ||
|
|
d00b3e993a | ||
|
|
596c72bfc6 | ||
|
|
ea535ad98d | ||
|
|
7368a0085a | ||
|
|
7dba4f19a9 | ||
|
|
90d398ea59 | ||
|
|
16a4643000 | ||
|
|
6fa78ca8b4 | ||
|
|
ec4bf4e47f | ||
|
|
da56489174 | ||
|
|
addb354296 | ||
|
|
848ed3240e | ||
|
|
f7bb1795e7 | ||
|
|
0355902fba | ||
|
|
24017e960c | ||
|
|
e86f5af5bf | ||
|
|
5c98e80430 | ||
|
|
f65f3f7a4a |
@@ -59,6 +59,10 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: vla_jepa
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
@@ -73,6 +77,8 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
|
||||
433
docs/source/molmoact2.mdx
Normal file
433
docs/source/molmoact2.mdx
Normal file
@@ -0,0 +1,433 @@
|
||||
# MolmoAct2 Policy
|
||||
|
||||
MolmoAct2 is the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
|
||||
training, evaluation, checkpointing, and dataset interfaces for easier use with
|
||||
LeRobot datasets.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Install LeRobot with the MolmoAct2 optional dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[molmoact2]"
|
||||
```
|
||||
|
||||
To run the models in this repository, you need an NVIDIA GPU. The measurements
|
||||
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
|
||||
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
|
||||
`gradient_checkpointing=true` and include the forward pass, backward pass,
|
||||
gradient clipping, optimizer step, and optimizer state allocation. Values are
|
||||
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
|
||||
dataloader workers, CUDA context, and fragmentation.
|
||||
|
||||
Multi-GPU training through `accelerate` increases throughput and global batch
|
||||
size, but this LeRobot port does not currently expose the original MolmoAct2
|
||||
`fsdp_devices` model-parallel training path. The current training script has
|
||||
not been tested for multi-node training.
|
||||
|
||||
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
|
||||
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
|
||||
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
|
||||
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
|
||||
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
|
||||
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
|
||||
|
||||
The repo has been tested with Ubuntu 22.04.
|
||||
|
||||
## Usage
|
||||
|
||||
To use MolmoAct2 in a LeRobot training config, set:
|
||||
|
||||
```python
|
||||
policy.type=molmoact2
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
|
||||
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
|
||||
the same LeRobot training loop, dataset transforms, checkpoint saving, and
|
||||
logging. The difference is only how the initial policy weights and processor
|
||||
state are loaded.
|
||||
|
||||
### Training With Original MolmoAct2 Weight
|
||||
|
||||
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
|
||||
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
|
||||
the original HF model files, then build its own policy processor from the
|
||||
dataset metadata and the policy options below.
|
||||
|
||||
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
|
||||
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
|
||||
augmentation, and LeRobot's checkpointing/logging path.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.setup_type="single franka robotic arm in libero" \
|
||||
--policy.control_mode="delta end-effector pose" \
|
||||
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.freeze_embedding=true \
|
||||
--policy.normalize_gripper=false \
|
||||
--policy.enable_knowledge_insulation=false \
|
||||
--policy.push_to_hub=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Training With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
|
||||
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
|
||||
restores the saved LeRobot policy config, model weights, processor, and
|
||||
normalization statistics. You can still override training-time options such as
|
||||
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.path=/path/to/pretrained_model \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Common Practices
|
||||
|
||||
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
|
||||
or a real-world dataset with less than 200 demonstrations, a global batch size of
|
||||
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
|
||||
cases, we intentionally keep the action expert fully trainable, which we found
|
||||
to be crucial for model performance. For larger fine-tuning datasets, larger
|
||||
global batch sizes and full fine-tuning are usually preferred.
|
||||
|
||||
### Common Policy Options
|
||||
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
|
||||
Use this for released MolmoAct2 weights.
|
||||
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
|
||||
created by LeRobot training.
|
||||
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
|
||||
`both`. `both` trains the flow-matching action expert and the discrete
|
||||
action-token loss.
|
||||
- `policy.train_action_expert_only`: trains only parameters whose names contain
|
||||
`action_expert`. It requires `policy.action_mode=continuous`.
|
||||
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
|
||||
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
|
||||
expert linear layers. When `policy.enable_lora_action_expert=false`, the
|
||||
action expert base weights remain fully trainable while the VLM is trained
|
||||
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
|
||||
action expert is also adapter-tuned instead of fully fine-tuned.
|
||||
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
|
||||
context K/V states before the action loss. The default is `false`.
|
||||
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
|
||||
`10`. This LeRobot port overrides the loaded checkpoint's
|
||||
`max_action_horizon` with this value.
|
||||
- `policy.n_action_steps`: number of actions consumed from each predicted
|
||||
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
|
||||
- `policy.setup_type`: text inserted into the prompt to describe the robot and
|
||||
scene, e.g. `single franka robotic arm in libero`. More examples are listed
|
||||
in the `metadata_by_tag` entries of
|
||||
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
|
||||
- `policy.control_mode`: text inserted into the prompt to describe the action
|
||||
space, e.g. `delta end-effector pose` or `absolute joint pose`.
|
||||
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
|
||||
processor.
|
||||
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
|
||||
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
|
||||
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
|
||||
example during training. We use `8` for fine-tuning.
|
||||
- `policy.num_inference_steps`: optional override for continuous action
|
||||
generation steps at inference time.
|
||||
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
|
||||
to reduce activation memory.
|
||||
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
|
||||
- `policy.normalize_gripper`: controls whether gripper dimensions are included
|
||||
in state/action quantile normalization. The default is `false`.
|
||||
- `policy.normalize_language`: normalizes task strings before prompt
|
||||
construction. The default is `true`.
|
||||
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
|
||||
Released checkpoints use `policy.expected_max_action_dim=32`.
|
||||
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
|
||||
infer it from images, state dimension, action dimension, action horizon, and
|
||||
discrete-action mode.
|
||||
|
||||
### Learning Rates
|
||||
|
||||
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
|
||||
fine-tuning experiments.
|
||||
|
||||
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
|
||||
`policy.optimizer_vit_lr=5e-6` for the vision tower,
|
||||
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
|
||||
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
|
||||
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
|
||||
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
|
||||
`policy.enable_lora_action_expert=false`, so the action expert is still fully
|
||||
fine-tuned with `policy.optimizer_action_expert_lr`. If
|
||||
`policy.enable_lora_action_expert=true`, the action expert is trained through
|
||||
LoRA adapters instead.
|
||||
- Action-expert-only fine-tuning trains only the action expert and uses
|
||||
`policy.optimizer_action_expert_lr=5e-5`.
|
||||
|
||||
You can override the full fine-tuning and action-expert learning rates with
|
||||
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
|
||||
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
|
||||
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
|
||||
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
|
||||
|
||||
### Dataset Quantile Statistics
|
||||
|
||||
MolmoAct2 defaults to quantile normalization for state and action features. If
|
||||
your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
Alternatively, train MolmoAct2 with mean/std normalization:
|
||||
|
||||
```bash
|
||||
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
|
||||
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
|
||||
fixed and use `policy.per_episode_seed=true`.
|
||||
|
||||
**Important:** We found that `num_steps_wait=10` does not reliably let the
|
||||
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
|
||||
results reported here use `num_steps_wait=50`.
|
||||
|
||||
### Evaluation With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
|
||||
normalization statistics are restored together with the model.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.use_amp=true \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
### Evaluation With Original MolmoAct2 Weight
|
||||
|
||||
You can evaluate a released Hugging Face checkpoint directly without first
|
||||
converting it to a LeRobot checkpoint. In this case, set
|
||||
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
|
||||
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
|
||||
normalization statistics, action horizon, prompt metadata, and image-key order
|
||||
from the checkpoint's `norm_stats.json`.
|
||||
|
||||
To fully replicate the MolmoAct2 paper results with released Hugging Face
|
||||
checkpoints, we recommend using the v0.5.1-pinned
|
||||
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
|
||||
branch. That branch matches the original evaluation settings used for the
|
||||
reported numbers.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.norm_tag=libero \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=float32 \
|
||||
--policy.use_amp=false \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_goal \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
|
||||
full LIBERO suite. The same command works for other released MolmoAct2
|
||||
checkpoints as long as the requested `policy.norm_tag` exists in that
|
||||
checkpoint's `norm_stats.json`.
|
||||
|
||||
### Common Evaluation Options
|
||||
|
||||
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
|
||||
flow-matching inference or `discrete` for action-token inference. It must be
|
||||
compatible with the training-time `policy.action_mode` saved in the
|
||||
checkpoint.
|
||||
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
|
||||
saved by LeRobot.
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
|
||||
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
|
||||
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
|
||||
image-key order, and action horizon from the original checkpoint's
|
||||
`norm_stats.json`. It is required for direct original-HF checkpoint
|
||||
evaluation.
|
||||
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
|
||||
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
|
||||
- `policy.use_amp`: runs the policy forward under autocast during eval. For
|
||||
`model_dtype=bfloat16`, keep this enabled.
|
||||
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
|
||||
graph path for faster repeated continuous-action rollout.
|
||||
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
|
||||
action generation deterministic per episode for replication.
|
||||
- `env.task`: comma-separated LIBERO suites or a single suite. Use
|
||||
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
|
||||
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
|
||||
by the policy processor.
|
||||
|
||||
## Performance Results
|
||||
|
||||
### LIBERO Benchmark Results
|
||||
|
||||
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
|
||||
compare and test its LeRobot implementation, we fine-tuned
|
||||
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
|
||||
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
|
||||
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
|
||||
results.
|
||||
|
||||
The LeRobot fine-tuned checkpoint reported here is available at
|
||||
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
|
||||
and was trained on
|
||||
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
|
||||
|
||||
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
|
||||
| -------------- | ---------------------: | -----------------: |
|
||||
| LIBERO Spatial | 98.4% | 97.8% |
|
||||
| LIBERO Object | 100.0% | 100.0% |
|
||||
| LIBERO Goal | 98.0% | 97.8% |
|
||||
| LIBERO 10 | 96.6% | 93.2% |
|
||||
| Average | 98.25% | 97.20% |
|
||||
|
||||
These results demonstrate MolmoAct2's strong performance across diverse robotic
|
||||
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
|
||||
evaluation section.
|
||||
|
||||
## Differences From the Original Implementation
|
||||
|
||||
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
|
||||
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
|
||||
differences from the original training repository are:
|
||||
|
||||
- The original paper training stack loads the model in fp32 and trains under
|
||||
mixed precision. This LeRobot port usually loads the checkpoint directly in
|
||||
`policy.model_dtype=bfloat16` for lower memory use.
|
||||
- The original repository uses its own FSDP/model-parallel training path. The
|
||||
LeRobot port uses the standard LeRobot/Accelerate training path and has not
|
||||
been tested for multi-node training.
|
||||
- The original repository supports sequence packing. The LeRobot port trains on
|
||||
one LeRobot sample per item and pads to an inferred fixed sequence budget.
|
||||
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
|
||||
dataset transforms, image augmentation, and Weights & Biases logging
|
||||
conventions.
|
||||
- The original training path supports mixed action horizons by padding to
|
||||
`max_action_horizon` and masking padded horizon slots in the action expert
|
||||
self-attention. This is useful when training across datasets with different
|
||||
control frequencies. The LeRobot port currently targets single-dataset
|
||||
fine-tuning, so `policy.chunk_size` overrides the checkpoint
|
||||
`max_action_horizon` and horizon masking is not implemented yet. Support for
|
||||
this mixed-horizon path is planned.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
39
docs/source/policy_molmoact2_README.md
Normal file
39
docs/source/policy_molmoact2_README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# MolmoAct2
|
||||
|
||||
This repository contains the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
|
||||
training, evaluation, checkpointing, and dataset compatibility.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## LIBERO Evaluation
|
||||
|
||||
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
|
||||
scene stabilize and can degrade measured success. All LIBERO evaluation results
|
||||
reported for this LeRobot implementation use `num_steps_wait=50`.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
39
docs/source/policy_vla_jepa_README.md
Normal file
39
docs/source/policy_vla_jepa_README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
177
docs/source/topreward.mdx
Normal file
177
docs/source/topreward.mdx
Normal file
@@ -0,0 +1,177 @@
|
||||
# TOPReward
|
||||
|
||||
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
|
||||
|
||||
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
|
||||
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
|
||||
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
|
||||
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Overview
|
||||
|
||||
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
|
||||
|
||||
- A trajectory video (a sequence of frames).
|
||||
- A task instruction (e.g. _"open the drawer"_).
|
||||
|
||||
it builds a chat prompt of the form
|
||||
|
||||
```text
|
||||
<video>
|
||||
"The above video shows a robot manipulation trajectory that completes the
|
||||
following task: <instruction> Decide whether the above statement is True
|
||||
or not. The answer is: True"
|
||||
```
|
||||
|
||||
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward.
|
||||
|
||||
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the label-masking logic. The processor owns the tokeniser and builds the full chat prompt (EO-1/Robometer pattern).
|
||||
|
||||
## What the LeRobot integration covers
|
||||
|
||||
- Standard `reward_model.type=topreward` configuration through LeRobot.
|
||||
- VLM loading via the `transformers` `Qwen3VLForConditionalGeneration` API.
|
||||
- Prompt assembly + tokenisation in the processor (matching upstream `QwenClient.compute_instruction_reward`).
|
||||
- `compute_reward()` returns one scalar log-prob per sample.
|
||||
- LeRobot reward-model save/load — `save_pretrained` writes only `config.json` (the VLM is identified by `vlm_name`).
|
||||
- An offline labeling script that writes a `topreward_progress.parquet` (SARM-compatible schema) for RA-BC and overlay.
|
||||
|
||||
The current LeRobot port supports the **Qwen3-VL client only**. Other upstream clients (Gemini, OpenAI, Gemma, Molmo) can be added as follow-up extras.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot following the [Installation Guide](./installation).
|
||||
2. Install the TOPReward optional extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[topreward]"
|
||||
```
|
||||
|
||||
or, with `uv` from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra topreward
|
||||
```
|
||||
|
||||
This pulls in `transformers`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
TOPReward expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | --------------------------- | --------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data for the task string |
|
||||
| `reward_model.max_frames` | `16` | Cap on frames per sample |
|
||||
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
|
||||
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
|
||||
|
||||
The model returns:
|
||||
|
||||
- `compute_reward(batch)`: one log-probability per sample. Higher = better task-video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the reward model directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
)
|
||||
reward_model = TOPRewardModel(cfg)
|
||||
```
|
||||
|
||||
### Use the reward factory
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"topreward",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor tokenises the full prompt (video + prefix + instruction suffix), writes Qwen-VL tensors + `prompt_length` under `observation.topreward.*`. The model reads those tensors, label-masks based on `prompt_length`, and extracts the log-prob reward.
|
||||
|
||||
### Offline dataset labeling
|
||||
|
||||
Write a `topreward_progress.parquet` for RA-BC training and overlay videos:
|
||||
|
||||
```bash
|
||||
# Sparse-dense (15 anchors per episode, matches upstream)
|
||||
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||
--dataset-repo-id lerobot/libero_10_image \
|
||||
--num-samples 15 \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
Then render the progress overlay for any episode:
|
||||
|
||||
```bash
|
||||
uv run examples/dataset/create_progress_videos.py \
|
||||
--repo-id lerobot/libero_10_image \
|
||||
--episode 0 \
|
||||
--progress-file topreward_progress.parquet \
|
||||
--gif
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Prompt knobs
|
||||
|
||||
The default prompt mirrors the upstream paper:
|
||||
|
||||
```text
|
||||
prompt_prefix = "The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
prompt_suffix_template = "{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
```
|
||||
|
||||
Both are exposed on `TOPRewardConfig` for ablation. The suffix template **must** contain `{instruction}`.
|
||||
|
||||
### Chat template
|
||||
|
||||
`add_chat_template=True` wraps the full prompt (including instruction) with the tokenizer's chat template before tokenisation. Default is `False`, matching the upstream paper's main experiments.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot port is **inference-only and zero-shot**; `forward()` is not overridden and `is_trainable` returns `False`.
|
||||
- Only the **Qwen3-VL family** is supported; other upstream clients are out of scope.
|
||||
- TOPReward inherits the underlying VLM's biases.
|
||||
|
||||
## References
|
||||
|
||||
- [TOPReward project page](https://topreward.github.io/webpage/)
|
||||
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
|
||||
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
|
||||
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2026topreward,
|
||||
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
|
||||
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
|
||||
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
|
||||
Krishna, Ranjay},
|
||||
journal={arXiv preprint arXiv:2602.19313},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The original TOPReward codebase is MIT-licensed. The LeRobot port follows the LeRobot Apache 2.0 license; the wrapped Qwen3-VL weights are subject to the original Qwen license.
|
||||
235
docs/source/vla_jepa.mdx
Normal file
235
docs/source/vla_jepa.mdx
Normal file
@@ -0,0 +1,235 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -198,6 +198,7 @@ wallx = [
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
@@ -211,9 +212,11 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -274,10 +277,12 @@ all = [
|
||||
"lerobot[multi_task_dit]",
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -288,6 +293,7 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -403,8 +409,11 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
"arange",
|
||||
"is_compileable",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE"
|
||||
"OT_VALUE",
|
||||
"VanderBilt"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
else:
|
||||
del config_data[field]
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
@@ -43,6 +44,7 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
|
||||
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
@@ -56,6 +57,7 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -88,7 +90,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
|
||||
"molmoact2".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -151,6 +154,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "molmoact2":
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -168,7 +179,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x".
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -203,6 +214,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -231,6 +246,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
dataset_meta: Any | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -406,6 +422,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -414,6 +431,23 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, MolmoAct2Config):
|
||||
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
processors = make_molmoact2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
@@ -499,6 +533,10 @@ def make_policy(
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
if ds_meta is not None:
|
||||
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
|
||||
if callable(set_dataset_feature_metadata):
|
||||
set_dataset_feature_metadata(ds_meta.features)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc = AutoProcessor.from_pretrained(
|
||||
str(cache_dir),
|
||||
trust_remote_code=True,
|
||||
fix_mistral_regex=False,
|
||||
)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
1
src/lerobot/policies/molmoact2/README.md
Symbolic link
1
src/lerobot/policies/molmoact2/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_molmoact2_README.md
|
||||
21
src/lerobot/policies/molmoact2/__init__.py
Normal file
21
src/lerobot/policies/molmoact2/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_molmoact2 import MolmoAct2Config
|
||||
from .modeling_molmoact2 import MolmoAct2Policy
|
||||
from .processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]
|
||||
519
src/lerobot/policies/molmoact2/configuration_molmoact2.py
Normal file
519
src/lerobot/policies/molmoact2/configuration_molmoact2.py
Normal file
@@ -0,0 +1,519 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import (
|
||||
AdamWConfig,
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
LRSchedulerConfig,
|
||||
OptimizerConfig,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from ..rtc.configuration_rtc import RTCConfig
|
||||
|
||||
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
|
||||
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
|
||||
MOLMOACT2_TASK_TOKEN_BUDGET = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
|
||||
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
|
||||
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
|
||||
|
||||
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
|
||||
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
|
||||
training steps"; build() is the first point where num_training_steps is known.
|
||||
"""
|
||||
|
||||
num_decay_steps: int | None
|
||||
|
||||
def build(self, optimizer, num_training_steps: int):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.peak_lr,
|
||||
decay_lr=self.decay_lr,
|
||||
num_warmup_steps=self.num_warmup_steps,
|
||||
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
|
||||
).build(optimizer, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def _round_up(value: int, multiple: int) -> int:
|
||||
return int(math.ceil(value / multiple) * multiple)
|
||||
|
||||
|
||||
def infer_molmoact2_max_sequence_length(
|
||||
*,
|
||||
num_images: int,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
action_horizon: int,
|
||||
include_discrete_action: bool,
|
||||
) -> int:
|
||||
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
|
||||
if num_images < 1:
|
||||
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim < 0:
|
||||
state_dim = 0
|
||||
if action_dim < 1:
|
||||
action_dim = 1
|
||||
if action_horizon < 1:
|
||||
action_horizon = 1
|
||||
|
||||
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
|
||||
prompt_tokens = (
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
|
||||
+ MOLMOACT2_TASK_TOKEN_BUDGET
|
||||
+ state_dim
|
||||
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
|
||||
)
|
||||
action_tokens = 0
|
||||
if include_discrete_action:
|
||||
action_tokens_per_step = max(
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
|
||||
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
|
||||
)
|
||||
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
|
||||
|
||||
return _round_up(
|
||||
image_tokens + prompt_tokens + action_tokens,
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
|
||||
)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("molmoact2")
|
||||
@dataclass
|
||||
class MolmoAct2Config(PreTrainedConfig):
|
||||
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
|
||||
|
||||
checkpoint_path: str = "allenai/MolmoAct2"
|
||||
checkpoint_revision: str | None = None
|
||||
checkpoint_force_download: bool = False
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 30
|
||||
n_action_steps: int = 30
|
||||
|
||||
action_mode: str = "both"
|
||||
inference_action_mode: str | None = None
|
||||
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
|
||||
discrete_generation_max_steps: int | None = None
|
||||
norm_tag: str | None = None
|
||||
|
||||
setup_type: str = ""
|
||||
control_mode: str = ""
|
||||
image_keys: list[str] = field(default_factory=list)
|
||||
normalize_language: bool = True
|
||||
add_setup_tokens: bool = True
|
||||
add_control_tokens: bool = True
|
||||
normalize_gripper: bool = False
|
||||
num_state_tokens: int = 256
|
||||
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
|
||||
# image/prompt/state/action token layout. Override only for unusual long prompts.
|
||||
max_sequence_length: int | None = None
|
||||
|
||||
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
|
||||
expected_max_action_dim: int = 32
|
||||
|
||||
# Flow-matching training knobs copied from the original MolmoAct2 training path.
|
||||
num_flow_timesteps: int = 8
|
||||
flow_matching_cutoff: float = 1.0
|
||||
flow_matching_time_offset: float = 0.001
|
||||
flow_matching_time_scale: float = 0.999
|
||||
flow_matching_beta_alpha: float = 1.0
|
||||
flow_matching_beta_beta: float = 1.5
|
||||
num_inference_steps: int | None = None
|
||||
mask_action_dim_padding: bool = True
|
||||
enable_inference_cuda_graph: bool = True
|
||||
# MolmoAct2-local eval option. When enabled, stochastic continuous action
|
||||
# generation uses a rollout-local generator derived from eval_seed.
|
||||
per_episode_seed: bool = False
|
||||
eval_seed: int | None = None
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
# Default is full finetuning with gradients from the action expert flowing into the VLM.
|
||||
enable_lora_vlm: bool = False
|
||||
lora_rank: int = 64
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.05
|
||||
lora_bias: str = "none"
|
||||
enable_lora_action_expert: bool = False
|
||||
enable_knowledge_insulation: bool = False
|
||||
freeze_embedding: bool = True
|
||||
train_action_expert_only: bool = False
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
model_dtype: str = "bfloat16"
|
||||
softmax_auxiliary_loss: bool = True
|
||||
softmax_auxiliary_loss_scale: float = 1e-4
|
||||
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_vit_lr: float = 5e-6
|
||||
optimizer_connector_lr: float = 5e-6
|
||||
optimizer_action_expert_lr: float = 5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-6
|
||||
optimizer_weight_decay: float = 0.0
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 200
|
||||
scheduler_decay_steps: int | None = None
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.QUANTILES,
|
||||
"ACTION": NormalizationMode.QUANTILES,
|
||||
}
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.action_mode not in {"continuous", "discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"Unsupported action_mode={self.action_mode!r}. "
|
||||
"Expected one of {'continuous', 'discrete', 'both'}."
|
||||
)
|
||||
if self.inference_action_mode not in {None, "continuous", "discrete"}:
|
||||
raise ValueError(
|
||||
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
|
||||
"Expected one of {None, 'continuous', 'discrete'}."
|
||||
)
|
||||
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
|
||||
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
|
||||
if self.train_action_expert_only and self.action_mode != "continuous":
|
||||
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
|
||||
if self.train_action_expert_only and self.enable_lora_vlm:
|
||||
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
|
||||
if self.enable_lora_action_expert and not self.enable_lora_vlm:
|
||||
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
|
||||
if self.chunk_size < 1:
|
||||
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
|
||||
if self.n_action_steps < 1:
|
||||
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
|
||||
)
|
||||
if self.expected_max_action_dim != 32:
|
||||
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
|
||||
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
|
||||
raise ValueError(
|
||||
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
|
||||
)
|
||||
if self.lora_rank < 1:
|
||||
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
|
||||
if self.lora_alpha < 1:
|
||||
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
|
||||
if not 0 <= self.lora_dropout <= 1:
|
||||
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
|
||||
if self.lora_bias not in {"none", "all", "lora_only"}:
|
||||
raise ValueError(
|
||||
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
|
||||
)
|
||||
if self.discrete_loss_token_weighting not in {
|
||||
"none",
|
||||
"token",
|
||||
"root_tokens",
|
||||
"root_subsegments",
|
||||
"root_subsegments_root_tokens",
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
|
||||
)
|
||||
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
|
||||
raise ValueError(
|
||||
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
|
||||
)
|
||||
if self.max_sequence_length is not None and self.max_sequence_length < 1:
|
||||
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
|
||||
|
||||
def inferred_max_sequence_length(
|
||||
self,
|
||||
*,
|
||||
num_images: int | None = None,
|
||||
state_dim: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
action_horizon: int | None = None,
|
||||
include_discrete_action: bool | None = None,
|
||||
) -> int:
|
||||
if self.max_sequence_length is not None:
|
||||
return int(self.max_sequence_length)
|
||||
|
||||
if num_images is None:
|
||||
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim is None:
|
||||
state_feature = self.robot_state_feature
|
||||
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
|
||||
if action_dim is None:
|
||||
action_feature = self.action_feature
|
||||
action_dim = (
|
||||
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
|
||||
)
|
||||
if action_horizon is None:
|
||||
action_horizon = self.chunk_size
|
||||
if include_discrete_action is None:
|
||||
include_discrete_action = self.action_mode in {"discrete", "both"}
|
||||
|
||||
return infer_molmoact2_max_sequence_length(
|
||||
num_images=int(num_images),
|
||||
state_dim=int(state_dim),
|
||||
action_dim=int(action_dim),
|
||||
action_horizon=int(action_horizon),
|
||||
include_discrete_action=bool(include_discrete_action),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
|
||||
self.dataset_feature_names = {}
|
||||
for key in (ACTION, OBS_STATE):
|
||||
feature = features.get(key) if isinstance(features, dict) else None
|
||||
if isinstance(feature, dict) and feature.get("names") is not None:
|
||||
self.dataset_feature_names[key] = feature["names"]
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up MolmoAct2 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"MolmoAct2 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(0,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.expected_max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def apply_norm_tag_metadata(self) -> None:
|
||||
if not str(self.norm_tag or "").strip():
|
||||
return
|
||||
metadata = _load_hf_norm_metadata_for_tag(
|
||||
self.checkpoint_path,
|
||||
revision=self.checkpoint_revision,
|
||||
force_download=bool(self.checkpoint_force_download),
|
||||
norm_tag=self.norm_tag,
|
||||
)
|
||||
if metadata.get("action_horizon") is not None:
|
||||
self.chunk_size = int(metadata["action_horizon"])
|
||||
if metadata.get("n_action_steps") is not None:
|
||||
self.n_action_steps = int(metadata["n_action_steps"])
|
||||
if not self.setup_type and metadata.get("setup_type") is not None:
|
||||
self.setup_type = str(metadata["setup_type"])
|
||||
if not self.control_mode and metadata.get("control_mode") is not None:
|
||||
self.control_mode = str(metadata["control_mode"])
|
||||
|
||||
def saved_policy_action_mode(self) -> str | None:
|
||||
pretrained_path = getattr(self, "pretrained_path", None)
|
||||
if pretrained_path is None:
|
||||
return None
|
||||
config_path = Path(pretrained_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
if mode in {"continuous", "discrete", "both"}:
|
||||
return str(mode)
|
||||
return None
|
||||
|
||||
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
|
||||
return saved_policy_action_mode or self.action_mode
|
||||
|
||||
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
return
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||
"continuous inference."
|
||||
)
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||
)
|
||||
|
||||
def validate_checkpoint_action_mode(
|
||||
self,
|
||||
checkpoint_action_mode: str,
|
||||
*,
|
||||
has_action_expert: bool,
|
||||
) -> None:
|
||||
if self.action_mode == "both" and checkpoint_action_mode != "both":
|
||||
raise ValueError(
|
||||
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||
f"got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode in {"continuous", "both"} and not has_action_expert:
|
||||
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
|
||||
def resolve_inference_action_mode(
|
||||
self,
|
||||
requested_mode: str | None,
|
||||
saved_policy_action_mode: str | None = None,
|
||||
) -> str:
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode is None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
raise ValueError(
|
||||
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
|
||||
"to either 'continuous' or 'discrete'."
|
||||
)
|
||||
if requested_mode not in {"continuous", "discrete"}:
|
||||
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
|
||||
return requested_mode
|
||||
17
src/lerobot/policies/molmoact2/hf_model/__init__.py
Normal file
17
src/lerobot/policies/molmoact2/hf_model/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
237
src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py
Normal file
237
src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py
Normal file
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_tokenizer_location(
|
||||
tokenizer_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
local_path = Path(str(tokenizer_path)).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=str(tokenizer_path),
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
class UniversalActionProcessor(ProcessorMixin):
|
||||
attributes: ClassVar[list[str]] = ["tokenizer"]
|
||||
tokenizer_class: str = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
min_token: int = 0,
|
||||
*,
|
||||
action_dim: int | None = None,
|
||||
time_horizon: int | None = None,
|
||||
):
|
||||
self.scale = scale
|
||||
self.vocab_size = vocab_size
|
||||
self.min_token = min_token
|
||||
|
||||
# Action horizon and dimension needed during decoding. These can be specified
|
||||
# in three ways (in order of priority):
|
||||
# 1. passed in as kwargs to decode()
|
||||
# 2. in the constructor
|
||||
# 3. cached from the last time decode() was called
|
||||
self.time_horizon = time_horizon
|
||||
self.action_dim = action_dim
|
||||
self.called_time_horizon = time_horizon
|
||||
self.called_action_dim = action_dim
|
||||
|
||||
super().__init__(tokenizer)
|
||||
self.bpe_tokenizer = self.tokenizer
|
||||
|
||||
def __call__(self, action_chunk: np.array) -> np.array:
|
||||
from scipy.fft import dct
|
||||
|
||||
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
|
||||
if action_chunk.ndim == 2:
|
||||
action_chunk = action_chunk[None, ...]
|
||||
|
||||
# Cache the time horizon and action dimension for decoding
|
||||
self.called_time_horizon = action_chunk.shape[-2]
|
||||
self.called_action_dim = action_chunk.shape[-1]
|
||||
|
||||
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
|
||||
dct_coeff = np.around(dct_coeff * self.scale)
|
||||
tokens = []
|
||||
for elem in dct_coeff:
|
||||
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
|
||||
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
|
||||
return tokens
|
||||
|
||||
def decode(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> np.array:
|
||||
from scipy.fft import idct
|
||||
|
||||
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
|
||||
self.action_dim = action_dim or self.action_dim or self.called_action_dim
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
@classmethod
|
||||
def fit(
|
||||
cls,
|
||||
action_data: list[np.array],
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> "UniversalActionProcessor":
|
||||
from scipy.fft import dct
|
||||
|
||||
# Run DCT over all inputs
|
||||
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
|
||||
|
||||
# Quantize and find min token
|
||||
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
|
||||
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
|
||||
min_vocab_size = max_token - min_token
|
||||
|
||||
assert min_vocab_size <= vocab_size, (
|
||||
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
|
||||
)
|
||||
if min_vocab_size + 100 > vocab_size:
|
||||
logging.warning(
|
||||
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
|
||||
f"size {vocab_size}, consider increasing vocab size"
|
||||
)
|
||||
|
||||
# Make token iterator for BPE training
|
||||
def _token_iter():
|
||||
for tokens in dct_tokens:
|
||||
rounded_tokens = np.around(tokens * scale) - min_token
|
||||
rounded_tokens = rounded_tokens.astype(int)
|
||||
string = "".join(map(chr, rounded_tokens))
|
||||
yield string
|
||||
|
||||
# Train BPE tokenizer
|
||||
bpe = ByteLevelBPETokenizer()
|
||||
|
||||
# Set up the entire range of possible tokens as the initial alphabet
|
||||
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
min_frequency=2,
|
||||
show_progress=True,
|
||||
special_tokens=[],
|
||||
initial_alphabet=alphabet,
|
||||
max_token_length=10000,
|
||||
)
|
||||
|
||||
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
|
||||
# because it doesn't support custom alphabets)
|
||||
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
|
||||
|
||||
return cls(
|
||||
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
|
||||
scale=scale,
|
||||
vocab_size=vocab_size,
|
||||
min_token=min_token,
|
||||
time_horizon=time_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained_local(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> "UniversalActionProcessor":
|
||||
location = Path(
|
||||
_resolve_tokenizer_location(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
processor_config = {}
|
||||
processor_config_path = location / "processor_config.json"
|
||||
if processor_config_path.exists():
|
||||
import json
|
||||
|
||||
processor_config = json.loads(processor_config_path.read_text())
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
|
||||
return cls(
|
||||
tokenizer,
|
||||
scale=processor_config.get("scale", 10),
|
||||
vocab_size=processor_config.get("vocab_size", 1024),
|
||||
min_token=processor_config.get("min_token", 0),
|
||||
action_dim=processor_config.get("action_dim"),
|
||||
time_horizon=processor_config.get("time_horizon"),
|
||||
)
|
||||
@@ -0,0 +1,553 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
MolmoAct2 configuration
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MolmoAct2VitConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
|
||||
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
|
||||
defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig
|
||||
>>> configuration = MolmoAct2VitConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
|
||||
>>> model = MolmoAct2VisionTransformer(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
base_config_key = "vit_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1152,
|
||||
intermediate_size: int = 4304,
|
||||
num_hidden_layers: int = 27,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 72,
|
||||
hidden_act: str = "gelu_pytorch_tanh",
|
||||
layer_norm_eps: float = 1e-6,
|
||||
image_default_input_size: tuple[int, int] = (378, 378),
|
||||
image_patch_size: int = 14,
|
||||
image_num_pos: int = 577,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
initializer_range: float = 0.02,
|
||||
float32_attention: bool = True,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_default_input_size = image_default_input_size
|
||||
self.image_patch_size = image_patch_size
|
||||
self.image_num_pos = image_num_pos
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.float32_attention = float32_attention
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
h, w = self.image_default_input_size
|
||||
return h // self.image_patch_size, w // self.image_patch_size
|
||||
|
||||
|
||||
class MolmoAct2AdapterConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
|
||||
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
|
||||
defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
|
||||
>>> vit_config = MolmoAct2VitConfig()
|
||||
>>> adapter_config = MolmoPoolingConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
|
||||
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> vit_configuration = model.vit_config
|
||||
>>> adapter_configuration = model.adapter_config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
base_config_key = "adapter_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_layers: tuple = (-3, -9),
|
||||
pooling_attention_mask: bool = False,
|
||||
hidden_size: int = 1152,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 72,
|
||||
float32_attention: bool = True,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
hidden_act: str = "silu",
|
||||
intermediate_size: int = 18944,
|
||||
text_hidden_size: int = 3584,
|
||||
image_feature_dropout: float = 0.0,
|
||||
initializer_range: float = 0.02,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||
self.vit_layers = vit_layers
|
||||
self.pooling_attention_mask = pooling_attention_mask
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.float32_attention = float32_attention
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.text_hidden_size = text_hidden_size
|
||||
self.image_feature_dropout = image_feature_dropout
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class MolmoAct2TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
|
||||
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
|
||||
|
||||
>>> # Initializing a MolmoAct2TextConfig
|
||||
>>> configuration = MolmoAct2TextConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2TextModel (with random weights)
|
||||
>>> model = MolmoAct2TextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2_text"
|
||||
base_config_key = "text_config"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"blocks.*.self_attn.att_proj": "colwise",
|
||||
"blocks.*.self_attn.attn_out": "rowwise",
|
||||
"blocks.*.mlp.ff_proj": "colwise",
|
||||
"blocks.*.mlp.ff_out": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"wte": (["input_ids"], ["inputs_embeds"]),
|
||||
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"ln_f": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 3584,
|
||||
num_attention_heads: int = 28,
|
||||
num_key_value_heads: int | None = 4,
|
||||
head_dim: int = 128,
|
||||
vocab_size: int = 152064,
|
||||
additional_vocab_size: int = 128,
|
||||
qkv_bias: bool = True,
|
||||
num_hidden_layers: int = 48,
|
||||
intermediate_size: int = 18944,
|
||||
hidden_act: str = "silu",
|
||||
embedding_dropout: float = 0.0,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
max_position_embeddings: int = 4096,
|
||||
rope_theta: float = 1000000.0,
|
||||
rope_scaling: dict[str, Any] = None,
|
||||
rope_scaling_layers: list[int] | None = None,
|
||||
use_qk_norm: bool = False,
|
||||
qk_norm_type: str = "olmo",
|
||||
layer_norm_eps: int = 1e-6,
|
||||
norm_after: bool = False,
|
||||
initializer_range: float = 0.02,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.additional_vocab_size = additional_vocab_size
|
||||
self.qkv_bias = qkv_bias
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.embedding_dropout = embedding_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_scaling_layers = rope_scaling_layers
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.qk_norm_type = qk_norm_type
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.norm_after = norm_after
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
rope_config_validation(self)
|
||||
|
||||
|
||||
class MolmoAct2ActionExpertConfig(PretrainedConfig):
|
||||
r"""Configuration for the MolmoAct2 modern action expert."""
|
||||
|
||||
model_type = "molmoact2_action_expert"
|
||||
base_config_key = "action_expert_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_action_horizon: int = 32,
|
||||
max_action_dim: int = 32,
|
||||
hidden_size: int = 1024,
|
||||
num_layers: int = 32,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 8.0 / 3.0,
|
||||
ffn_multiple_of: int = 256,
|
||||
timestep_embed_dim: int = 256,
|
||||
dropout: float = 0.0,
|
||||
attn_dropout: float = 0.0,
|
||||
context_layer_norm: bool = True,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_eps: float = 1e-6,
|
||||
rope: bool = True,
|
||||
causal_attn: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.max_action_horizon = max_action_horizon
|
||||
self.max_action_dim = max_action_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.ffn_multiple_of = ffn_multiple_of
|
||||
self.timestep_embed_dim = timestep_embed_dim
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = attn_dropout
|
||||
self.context_layer_norm = context_layer_norm
|
||||
self.qk_norm = qk_norm
|
||||
self.qk_norm_eps = qk_norm_eps
|
||||
self.rope = rope
|
||||
self.causal_attn = causal_attn
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
|
||||
# them out of the public nested config avoids duplicated sources of truth.
|
||||
output.pop("max_action_horizon", None)
|
||||
output.pop("max_action_dim", None)
|
||||
return output
|
||||
|
||||
|
||||
class MolmoAct2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
|
||||
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig
|
||||
>>> vit_config = MolmoAct2VitConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2AdapterConfig
|
||||
>>> adapter_config = MolmoAct2AdapterConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2TextConfig
|
||||
>>> text_config = MolmoAct2TextConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2Config
|
||||
>>> configuration = MolmoAct2Config(
|
||||
>>> vit_config=vit_config,
|
||||
>>> adapter_config=adapter_config,
|
||||
>>> text_config=text_config,
|
||||
>>> image_start_token_id=151936,
|
||||
>>> image_end_token_id=151937,
|
||||
>>> image_patch_id=151938,
|
||||
>>> image_col_id=151939,
|
||||
>>> low_res_image_start_token_id=151940,
|
||||
>>> image_low_res_id=151942,
|
||||
>>> frame_start_token_id=151943,
|
||||
>>> frame_end_token_id=151944,
|
||||
>>> )
|
||||
|
||||
>>> # Initializing a model
|
||||
>>> model = MolmoAct2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
sub_configs = {
|
||||
"text_config": MolmoAct2TextConfig,
|
||||
"vit_config": MolmoAct2VitConfig,
|
||||
"adapter_config": MolmoAct2AdapterConfig,
|
||||
"action_expert_config": MolmoAct2ActionExpertConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_config: MolmoAct2VitConfig = None,
|
||||
adapter_config: MolmoAct2AdapterConfig = None,
|
||||
text_config: MolmoAct2TextConfig = None,
|
||||
action_expert_config: MolmoAct2ActionExpertConfig = None,
|
||||
image_start_token_id: int = None,
|
||||
low_res_image_start_token_id: int = None,
|
||||
image_end_token_id: int = None,
|
||||
image_low_res_id: int = None,
|
||||
image_patch_id: int = None,
|
||||
image_col_id: int = None,
|
||||
frame_start_token_id: int = None,
|
||||
frame_end_token_id: int = None,
|
||||
use_frame_special_tokens: bool = True,
|
||||
initializer_range: float = 0.02,
|
||||
add_action_expert: bool = True,
|
||||
max_action_dim: int = 32,
|
||||
max_action_horizon: int = 30,
|
||||
n_obs_steps: int = 30,
|
||||
action_mode: str = "both",
|
||||
state_format: str = "discrete",
|
||||
flow_matching_num_steps: int = 10,
|
||||
flow_matching_cutoff: float = 1.0,
|
||||
flow_matching_time_offset: float = 0.001,
|
||||
flow_matching_time_scale: float = 0.999,
|
||||
flow_matching_beta_alpha: float = 1.0,
|
||||
flow_matching_beta_beta: float = 1.5,
|
||||
mask_action_dim_padding: bool = True,
|
||||
enable_depth_reasoning: bool = False,
|
||||
depth_mode: int = 2,
|
||||
num_depth_codes: int = 100,
|
||||
action_expert_depth_gate: bool = False,
|
||||
action_expert_depth_gate_per_layer: bool = False,
|
||||
action_expert_depth_gate_init_bias: float = -4.0,
|
||||
action_output_token_id: int = None,
|
||||
action_start_token_id: int = None,
|
||||
action_end_token_id: int = None,
|
||||
action_token_start_id: int = None,
|
||||
num_action_tokens: int = 0,
|
||||
depth_output_token_id: int = None,
|
||||
depth_start_token_id: int = None,
|
||||
depth_end_token_id: int = None,
|
||||
depth_token_start_id: int = None,
|
||||
num_depth_tokens: int = 0,
|
||||
state_start_token_id: int = None,
|
||||
state_end_token_id: int = None,
|
||||
state_token_start_id: int = None,
|
||||
num_state_tokens: int = 0,
|
||||
add_setup_tokens: bool = True,
|
||||
add_control_tokens: bool = True,
|
||||
norm_stats_filename: str = "norm_stats.json",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if vit_config is None:
|
||||
self.vit_config = MolmoAct2VitConfig()
|
||||
elif isinstance(vit_config, dict):
|
||||
self.vit_config = MolmoAct2VitConfig(**vit_config)
|
||||
else:
|
||||
self.vit_config = vit_config
|
||||
if adapter_config is None:
|
||||
self.adapter_config = MolmoAct2AdapterConfig()
|
||||
elif isinstance(adapter_config, dict):
|
||||
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
|
||||
else:
|
||||
self.adapter_config = adapter_config
|
||||
if text_config is None:
|
||||
self.text_config = MolmoAct2TextConfig()
|
||||
elif isinstance(text_config, dict):
|
||||
self.text_config = MolmoAct2TextConfig(**text_config)
|
||||
else:
|
||||
self.text_config = text_config
|
||||
self.add_action_expert = bool(add_action_expert)
|
||||
if not self.add_action_expert:
|
||||
self.action_expert_config = None
|
||||
elif action_expert_config is None:
|
||||
self.action_expert_config = MolmoAct2ActionExpertConfig(
|
||||
max_action_horizon=max_action_horizon,
|
||||
max_action_dim=max_action_dim,
|
||||
num_layers=self.text_config.num_hidden_layers,
|
||||
)
|
||||
elif isinstance(action_expert_config, dict):
|
||||
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
|
||||
else:
|
||||
self.action_expert_config = action_expert_config
|
||||
if self.add_action_expert:
|
||||
self.action_expert_config.max_action_dim = int(max_action_dim)
|
||||
self.action_expert_config.max_action_horizon = int(max_action_horizon)
|
||||
self._validate_release_action_config(
|
||||
state_format=state_format,
|
||||
)
|
||||
self.image_start_token_id = image_start_token_id
|
||||
self.low_res_image_start_token_id = low_res_image_start_token_id
|
||||
self.image_end_token_id = image_end_token_id
|
||||
self.image_low_res_id = image_low_res_id
|
||||
self.image_high_res_id = image_patch_id
|
||||
self.image_patch_id = image_patch_id
|
||||
self.image_col_id = image_col_id
|
||||
self.frame_start_token_id = frame_start_token_id
|
||||
self.frame_end_token_id = frame_end_token_id
|
||||
self.use_frame_special_tokens = use_frame_special_tokens
|
||||
self.initializer_range = initializer_range
|
||||
self.max_action_dim = max_action_dim
|
||||
self.max_action_horizon = max_action_horizon
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.action_mode = action_mode
|
||||
self.state_format = state_format
|
||||
self.flow_matching_num_steps = flow_matching_num_steps
|
||||
self.flow_matching_cutoff = flow_matching_cutoff
|
||||
self.flow_matching_time_offset = flow_matching_time_offset
|
||||
self.flow_matching_time_scale = flow_matching_time_scale
|
||||
self.flow_matching_beta_alpha = flow_matching_beta_alpha
|
||||
self.flow_matching_beta_beta = flow_matching_beta_beta
|
||||
self.mask_action_dim_padding = mask_action_dim_padding
|
||||
self.enable_depth_reasoning = enable_depth_reasoning
|
||||
self.depth_mode = depth_mode
|
||||
self.num_depth_codes = num_depth_codes
|
||||
self.action_expert_depth_gate = action_expert_depth_gate
|
||||
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
|
||||
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
|
||||
self.action_output_token_id = action_output_token_id
|
||||
self.action_start_token_id = action_start_token_id
|
||||
self.action_end_token_id = action_end_token_id
|
||||
self.action_token_start_id = action_token_start_id
|
||||
self.num_action_tokens = num_action_tokens
|
||||
self.depth_output_token_id = depth_output_token_id
|
||||
self.depth_start_token_id = depth_start_token_id
|
||||
self.depth_end_token_id = depth_end_token_id
|
||||
self.depth_token_start_id = depth_token_start_id
|
||||
self.num_depth_tokens = num_depth_tokens
|
||||
self.state_start_token_id = state_start_token_id
|
||||
self.state_end_token_id = state_end_token_id
|
||||
self.state_token_start_id = state_token_start_id
|
||||
self.num_state_tokens = num_state_tokens
|
||||
self.add_setup_tokens = add_setup_tokens
|
||||
self.add_control_tokens = add_control_tokens
|
||||
self.norm_stats_filename = norm_stats_filename
|
||||
|
||||
@staticmethod
|
||||
def _validate_release_action_config(
|
||||
*,
|
||||
state_format: str,
|
||||
) -> None:
|
||||
if state_format != "discrete":
|
||||
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
assert self.vit_config is not None
|
||||
return self.vit_config.image_num_patch
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.text_config.num_attention_heads
|
||||
|
||||
@property
|
||||
def num_key_value_heads(self):
|
||||
return self.text_config.num_key_value_heads
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.text_config.head_dim
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.text_config.num_hidden_layers
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.text_config.hidden_size
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.text_config.vocab_size
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.text_config.max_position_embeddings
|
||||
|
||||
|
||||
MolmoAct2VitConfig.register_for_auto_class()
|
||||
MolmoAct2AdapterConfig.register_for_auto_class()
|
||||
MolmoAct2TextConfig.register_for_auto_class()
|
||||
MolmoAct2ActionExpertConfig.register_for_auto_class()
|
||||
MolmoAct2Config.register_for_auto_class()
|
||||
@@ -0,0 +1,564 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Image processor class for MolmoAct2"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import numpy as np
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_flat_list_of_images,
|
||||
valid_images,
|
||||
to_numpy_array,
|
||||
)
|
||||
from transformers.image_transforms import convert_to_rgb
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def normalize_image(
|
||||
image: np.ndarray,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
) -> np.ndarray:
|
||||
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
desired_output_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
dtype = image.dtype
|
||||
if torch.is_floating_point(image):
|
||||
in_min = 0.0
|
||||
in_max = 1.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||
|
||||
resized = resized.to(torch.float32)
|
||||
resized = (resized - in_min) / (in_max - in_min)
|
||||
|
||||
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def select_tiling(h, w, patch_size, max_num_crops):
|
||||
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||
original_size = np.stack([h, w]) # [1, 2]
|
||||
original_res = h * w
|
||||
tilings = []
|
||||
for i in range(1, max_num_crops + 1):
|
||||
for j in range(1, max_num_crops + 1):
|
||||
if i * j <= max_num_crops:
|
||||
tilings.append((i, j))
|
||||
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
||||
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
|
||||
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
||||
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
||||
|
||||
# How much we would need to scale the image to fit exactly in each tiling
|
||||
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
||||
|
||||
# The original size can be zero in rare cases if the image is smaller than the margin
|
||||
# In those cases letting the scale become infinite means the tiling is based on the
|
||||
# other side, or falls back to the smallest tiling
|
||||
with np.errstate(divide="ignore"):
|
||||
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
|
||||
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
||||
if np.all(required_scale < 1):
|
||||
# We are forced to downscale, so try to minimize the amount of downscaling
|
||||
ix = np.argmax(required_scale)
|
||||
else:
|
||||
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
||||
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
||||
ix = np.argmin(required_scale)
|
||||
return candidate_tilings[ix]
|
||||
|
||||
|
||||
def build_resized_image(
|
||||
image: np.ndarray,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
resized = resize_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
)
|
||||
resized = normalize_image(resized, image_mean, image_std)
|
||||
if len(resized.shape) == 3:
|
||||
resized = np.expand_dims(resized, 0)
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||
return resized, resize_idx
|
||||
|
||||
|
||||
def build_overlapping_crops(
|
||||
image: np.ndarray,
|
||||
max_crops: int,
|
||||
overlap_margins: list[int],
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Decompose an image into a set of overlapping crops
|
||||
|
||||
:return crop_arr: [n_crops, h, w, 3] The crops
|
||||
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
|
||||
the crops were extracted from, what patch in `crop_arr` it corresponds to
|
||||
"""
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
assert base_image_input_size[0] == base_image_input_size[1]
|
||||
|
||||
left_margin, right_margin = overlap_margins
|
||||
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
|
||||
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
||||
crop_window_size = crop_window_patches * image_patch_size
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
|
||||
# Decide how to tile the image, to account for the overlap margins we compute the tiling
|
||||
# as if we had an image without the margins and were using a crop size without the margins
|
||||
tiling = select_tiling(
|
||||
original_image_h - total_margin_pixels,
|
||||
original_image_w - total_margin_pixels,
|
||||
crop_window_size,
|
||||
max_crops,
|
||||
)
|
||||
|
||||
src = resize_image(
|
||||
image,
|
||||
[
|
||||
tiling[0] * crop_window_size + total_margin_pixels,
|
||||
tiling[1] * crop_window_size + total_margin_pixels,
|
||||
],
|
||||
resample,
|
||||
)
|
||||
src = normalize_image(src, image_mean, image_std)
|
||||
|
||||
# Now we have to split the image into crops, and track what patches came from
|
||||
# where in `patch_idx_arr`
|
||||
n_crops = tiling[0] * tiling[1]
|
||||
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
|
||||
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
|
||||
on_crop = 0
|
||||
for i in range(tiling[0]):
|
||||
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
|
||||
# which results in overlapping crop windows
|
||||
y0 = i * crop_window_size
|
||||
for j in range(tiling[1]):
|
||||
x0 = j * crop_window_size
|
||||
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
|
||||
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
|
||||
patch_idx += on_crop * crop_patch_h * crop_patch_w
|
||||
|
||||
# Mask out idx that are in the overlap region
|
||||
if i != 0:
|
||||
patch_idx[:left_margin, :] = -1
|
||||
if j != 0:
|
||||
patch_idx[:, :left_margin] = -1
|
||||
if i != tiling[0] - 1:
|
||||
patch_idx[-right_margin:, :] = -1
|
||||
if j != tiling[1] - 1:
|
||||
patch_idx[:, -right_margin:] = -1
|
||||
patch_idx_arr[on_crop] = patch_idx
|
||||
on_crop += 1
|
||||
|
||||
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
|
||||
# so it is ordered left-to-right order
|
||||
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
|
||||
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
|
||||
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
|
||||
|
||||
# Now get the parts not in the overlap region, so it should map each patch in `src`
|
||||
# to the correct patch it should come from in `crop_arr`
|
||||
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
|
||||
src.shape[0] // image_patch_size,
|
||||
src.shape[1] // image_patch_size,
|
||||
)
|
||||
return crop_arr, patch_idx_arr
|
||||
|
||||
|
||||
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||
if len(array.shape) == 3:
|
||||
n_crops, h, w = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||
return array
|
||||
else:
|
||||
n_crops, h, w, c = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||
return array
|
||||
|
||||
|
||||
def arange_for_pooling(
|
||||
idx_arr: np.ndarray,
|
||||
pool_h: int,
|
||||
pool_w: int,
|
||||
) -> np.ndarray:
|
||||
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||
idx_arr = np.pad(
|
||||
idx_arr,
|
||||
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||
mode="constant",
|
||||
constant_values=-1,
|
||||
)
|
||||
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||
|
||||
|
||||
def image_to_patches_and_grids(
|
||||
image: np.ndarray,
|
||||
max_crops: int,
|
||||
overlap_margins: list[int],
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
image_pooling_w: int,
|
||||
image_pooling_h: int,
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
||||
:return crops, the image crops to processes with the ViT
|
||||
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||
patches in `crops` to pool for that token, masked with -1
|
||||
"""
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
base_image_input_d = image_patch_size
|
||||
pooling_w = image_pooling_w
|
||||
pooling_h = image_pooling_h
|
||||
crop_patch_w = base_image_input_size[1] // base_image_input_d
|
||||
crop_patch_h = base_image_input_size[0] // base_image_input_d
|
||||
|
||||
if crop_mode == "resize":
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
resized_h, resized_w = resize_idx.shape[:2]
|
||||
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||
image_grid = [np.array([resized_h, resized_w, 0, 0])]
|
||||
return (
|
||||
np.stack(image_grid, 0),
|
||||
batch_pixels_to_patches(resized, image_patch_size),
|
||||
resize_idx,
|
||||
)
|
||||
|
||||
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
|
||||
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
|
||||
|
||||
crop_arr, patch_idx_arr = build_overlapping_crops(
|
||||
image,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
||||
h, w = pooling_idx.shape[:2]
|
||||
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||
|
||||
# Finally do the same for the global image
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
crop_arr = np.concatenate([resized, crop_arr], 0)
|
||||
|
||||
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
resized_h, resized_w = resize_idx.shape[:2]
|
||||
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||
|
||||
# Global image goes first, so the order of patches in previous crops gets increased
|
||||
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
|
||||
pooling_idx = np.concatenate([resize_idx, pooling_idx])
|
||||
image_grid = [np.array([resized_h, resized_w, h, w])]
|
||||
|
||||
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
|
||||
|
||||
|
||||
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
|
||||
max_crops: int | None
|
||||
overlap_margins: list[int] | None
|
||||
crop_mode: str | None
|
||||
patch_size: int | None
|
||||
pooling_size: list[int] | None
|
||||
|
||||
|
||||
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
||||
|
||||
Args:
|
||||
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use when resizing the image.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
max_crops (`int`, *optional*, defaults to `8`):
|
||||
Maximum number of crops to use per image.
|
||||
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
||||
Overlap margins to use.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
||||
The pooling size of the vision adapter.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool = True,
|
||||
max_crops: int = 8,
|
||||
overlap_margins: list[int] = [4, 4],
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
patch_size: int = 14,
|
||||
pooling_size: list[int] = [2, 2],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 378, "width": 378}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
self.size = size
|
||||
|
||||
self.resample = resample
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
self.max_crops = max_crops
|
||||
self.overlap_margins = overlap_margins
|
||||
self.crop_mode = crop_mode
|
||||
self.patch_size = patch_size
|
||||
self.pooling_size = pooling_size
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
max_crops: int | None = None,
|
||||
overlap_margins: list[int] | None = None,
|
||||
crop_mode: str | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_size: list[int] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess.
|
||||
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
max_crops (`int`, *optional*, defaults to `self.max_crops`):
|
||||
Maximum number of crops to use per image.
|
||||
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
|
||||
Overlap margins to use.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||
The pooling size of the vision adapter.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
A `BatchFeature` containing the following keys:
|
||||
- `pixel_values`: The preprocessed images.
|
||||
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
|
||||
- `image_grids`: The image grids.
|
||||
- `image_num_crops`: The number of crops for each image.
|
||||
"""
|
||||
if size is not None:
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
else:
|
||||
size = {**self.size}
|
||||
|
||||
base_image_input_size = [size["height"], size["width"]]
|
||||
|
||||
resample = resample or self.resample
|
||||
image_mean = image_mean or self.image_mean
|
||||
image_std = image_std or self.image_std
|
||||
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||
|
||||
max_crops = max_crops or self.max_crops
|
||||
overlap_margins = overlap_margins or self.overlap_margins
|
||||
crop_mode = crop_mode or self.crop_mode
|
||||
patch_size = patch_size or self.patch_size
|
||||
pooling_size = pooling_size or self.pooling_size
|
||||
|
||||
image_pooling_h, image_pooling_w = pooling_size
|
||||
|
||||
if images is not None:
|
||||
images = self.fetch_images(images)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
data = {}
|
||||
if images is not None:
|
||||
batch_grids = []
|
||||
batch_crops = []
|
||||
batch_pooled_patches_idx = []
|
||||
batch_num_crops = []
|
||||
|
||||
for image in images:
|
||||
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||
image,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
patch_size,
|
||||
image_pooling_w,
|
||||
image_pooling_h,
|
||||
crop_mode,
|
||||
)
|
||||
batch_grids.append(image_grid)
|
||||
batch_crops.append(crops)
|
||||
batch_pooled_patches_idx.append(pooled_idx)
|
||||
batch_num_crops.append(crops.shape[0])
|
||||
|
||||
pixel_values = np.concatenate(batch_crops, 0)
|
||||
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
image_grids = np.concatenate(batch_grids, 0)
|
||||
image_num_crops = np.array(batch_num_crops)
|
||||
|
||||
data.update(
|
||||
pixel_values=pixel_values,
|
||||
image_token_pooling=image_token_pooling,
|
||||
image_grids=image_grids,
|
||||
image_num_crops=image_num_crops,
|
||||
)
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
MolmoAct2ImageProcessor.register_for_auto_class()
|
||||
748
src/lerobot/policies/molmoact2/hf_model/inference.py
Normal file
748
src/lerobot/policies/molmoact2/hf_model/inference.py
Normal file
@@ -0,0 +1,748 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Inference utilities for MolmoAct2"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionFlowInputs:
|
||||
trajectory: torch.Tensor
|
||||
context: Any
|
||||
modulations: Sequence[Any]
|
||||
action_dim_is_pad: torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionFlowCudaGraph:
|
||||
key: tuple[Any, ...]
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_inputs: _ActionFlowInputs
|
||||
output: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphLayerStage:
|
||||
residual: torch.Tensor
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphPostStage:
|
||||
graph: torch.cuda.CUDAGraph
|
||||
attn_context: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraph:
|
||||
cache_key: tuple[Any, ...]
|
||||
pre_graph: torch.cuda.CUDAGraph
|
||||
token_ids: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
|
||||
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
|
||||
output: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphSpec:
|
||||
eligible: bool
|
||||
cache_key_prefix: tuple[Any, ...]
|
||||
num_hidden_layers: int
|
||||
head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
|
||||
if past_key_values is None:
|
||||
return 0
|
||||
seq_len = past_key_values.get_seq_length()
|
||||
if torch.is_tensor(seq_len):
|
||||
return int(seq_len.item())
|
||||
return int(seq_len)
|
||||
|
||||
|
||||
def _cache_max_len_int(past_key_values: Cache | None) -> int:
|
||||
if past_key_values is None:
|
||||
return -1
|
||||
max_len = past_key_values.get_max_cache_shape()
|
||||
if torch.is_tensor(max_len):
|
||||
return int(max_len.item())
|
||||
return int(max_len)
|
||||
|
||||
|
||||
def _iter_cache_key_values(
|
||||
past_key_values: Cache,
|
||||
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
|
||||
layers = getattr(past_key_values, "layers", None)
|
||||
if layers is not None:
|
||||
for layer in layers:
|
||||
yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
||||
return
|
||||
for layer in past_key_values:
|
||||
yield layer[0], layer[1]
|
||||
|
||||
|
||||
class _DepthDecodeStaticLayerCache:
|
||||
is_compileable = False
|
||||
is_sliding = False
|
||||
|
||||
def __init__(self, max_cache_len: int) -> None:
|
||||
self.max_cache_len = int(max_cache_len)
|
||||
self.cumulative_length = 0
|
||||
self.keys: torch.Tensor | None = None
|
||||
self.values: torch.Tensor | None = None
|
||||
|
||||
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
||||
bsz, n_heads = key_states.shape[:2]
|
||||
self.keys = torch.empty(
|
||||
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
||||
dtype=key_states.dtype,
|
||||
device=key_states.device,
|
||||
)
|
||||
self.values = torch.empty(
|
||||
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
||||
dtype=value_states.dtype,
|
||||
device=value_states.device,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.keys is None:
|
||||
self._allocate(key_states, value_states)
|
||||
start = self.cumulative_length
|
||||
end = start + key_states.shape[-2]
|
||||
if end > self.max_cache_len:
|
||||
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
|
||||
self.keys[:, :, start:end, :].copy_(key_states)
|
||||
self.values[:, :, start:end, :].copy_(value_states)
|
||||
self.cumulative_length = end
|
||||
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
||||
|
||||
def get_seq_length(self) -> int:
|
||||
return self.cumulative_length
|
||||
|
||||
def get_max_cache_shape(self) -> int:
|
||||
return -1
|
||||
|
||||
def reset(self) -> None:
|
||||
self.cumulative_length = 0
|
||||
|
||||
|
||||
class _DepthDecodeStaticCache(Cache):
|
||||
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
super().__init__(
|
||||
layers=[
|
||||
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
|
||||
for _ in range(text_config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def get_seq_length(self, layer_idx: int = 0) -> int:
|
||||
return self.layers[layer_idx].get_seq_length()
|
||||
|
||||
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
||||
return self.layers[layer_idx].get_max_cache_shape()
|
||||
|
||||
def reset(self) -> None:
|
||||
for layer in self.layers:
|
||||
layer.reset()
|
||||
|
||||
|
||||
class ActionCudaGraphManager:
|
||||
def __init__(self, model: Any) -> None:
|
||||
self.model = model
|
||||
self.enabled = True
|
||||
self.action_flow_graph: _ActionFlowCudaGraph | None = None
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
self.enabled = bool(enabled)
|
||||
|
||||
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
|
||||
action_model = self.model
|
||||
if not self.enabled:
|
||||
return False
|
||||
if action_model.training or action_model._require_action_expert().training:
|
||||
return False
|
||||
if inputs.trajectory.device.type != "cuda":
|
||||
return False
|
||||
|
||||
def all_on_cuda():
|
||||
yield inputs.trajectory
|
||||
for k, v in inputs.context.kv_contexts:
|
||||
yield k
|
||||
yield v
|
||||
for t in (
|
||||
inputs.context.cross_mask,
|
||||
inputs.context.self_mask,
|
||||
inputs.context.valid_action,
|
||||
inputs.action_dim_is_pad,
|
||||
):
|
||||
if t is not None:
|
||||
yield t
|
||||
if inputs.context.rope_cache is not None:
|
||||
yield from inputs.context.rope_cache
|
||||
for step in inputs.modulations:
|
||||
yield step.conditioning
|
||||
for block_modulation in step.block_modulations:
|
||||
yield from block_modulation
|
||||
yield from step.final_modulation
|
||||
|
||||
return all(t.device.type == "cuda" for t in all_on_cuda())
|
||||
|
||||
def run_action_flow(
|
||||
self,
|
||||
inputs: _ActionFlowInputs,
|
||||
steps: int,
|
||||
run_loop,
|
||||
) -> torch.Tensor:
|
||||
key = _cuda_graph_key(inputs, steps)
|
||||
cache = self.action_flow_graph
|
||||
if cache is None or cache.key != key:
|
||||
static_inputs = _clone_static_inputs(inputs)
|
||||
graph, output = _capture_cuda_graph(
|
||||
lambda: run_loop(static_inputs, steps),
|
||||
inputs.trajectory.device,
|
||||
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
|
||||
)
|
||||
cache = _ActionFlowCudaGraph(
|
||||
key=key,
|
||||
graph=graph,
|
||||
static_inputs=static_inputs,
|
||||
output=output,
|
||||
)
|
||||
self.action_flow_graph = cache
|
||||
else:
|
||||
_copy_inputs_(cache.static_inputs, inputs)
|
||||
|
||||
cache.graph.replay()
|
||||
return cache.output.clone()
|
||||
|
||||
|
||||
class DepthDecodeCudaGraphManager:
|
||||
def __init__(self, model: Any) -> None:
|
||||
self.model = model
|
||||
self.backbone = model.model
|
||||
self.enabled = True
|
||||
self.graph: _DepthDecodeCudaGraph | None = None
|
||||
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
self.enabled = bool(enabled)
|
||||
|
||||
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
||||
return _DepthDecodeStaticCache(
|
||||
config=self.model.config.text_config,
|
||||
max_cache_len=max_cache_len,
|
||||
)
|
||||
|
||||
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
||||
static = self.graph_spec
|
||||
if static is None:
|
||||
cfg = self.backbone.transformer.config
|
||||
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
||||
static = _DepthDecodeCudaGraphSpec(
|
||||
eligible=(
|
||||
not cfg.norm_after
|
||||
and cfg.rope_scaling_layers is None
|
||||
and getattr(rotary_emb, "rope_type", None) == "default"
|
||||
and cfg._attn_implementation == "sdpa"
|
||||
),
|
||||
cache_key_prefix=(
|
||||
cfg.hidden_size,
|
||||
cfg.num_attention_heads,
|
||||
cfg.num_key_value_heads,
|
||||
cfg.head_dim,
|
||||
cfg.num_hidden_layers,
|
||||
cfg.use_qk_norm,
|
||||
cfg.qk_norm_type,
|
||||
cfg._attn_implementation,
|
||||
),
|
||||
num_hidden_layers=cfg.num_hidden_layers,
|
||||
head_dim=cfg.head_dim,
|
||||
num_attention_heads=cfg.num_attention_heads,
|
||||
)
|
||||
self.graph_spec = static
|
||||
return static
|
||||
|
||||
def can_use(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> bool:
|
||||
if not self.enabled or self.model.training or self.backbone.transformer.training:
|
||||
return False
|
||||
if next_input_ids.device.type != "cuda":
|
||||
return False
|
||||
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
|
||||
return False
|
||||
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
||||
return False
|
||||
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
|
||||
return False
|
||||
return self._depth_decode_spec().eligible
|
||||
|
||||
def _depth_decode_key(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> tuple[Any, ...]:
|
||||
device = next_input_ids.device
|
||||
return (
|
||||
self._depth_decode_spec().cache_key_prefix,
|
||||
device.type,
|
||||
device.index,
|
||||
self.model.lm_head.weight.dtype,
|
||||
attention_bias.shape[-1],
|
||||
)
|
||||
|
||||
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
|
||||
emb = self.backbone.transformer.rotary_emb
|
||||
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
||||
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
||||
|
||||
def _depth_decode_pre_layer(
|
||||
self,
|
||||
layer_idx: int,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
block = self.backbone.transformer.blocks[layer_idx]
|
||||
attention = block.self_attn
|
||||
residual = hidden_states
|
||||
hidden_states = block.attn_norm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, attention.head_dim)
|
||||
qkv = attention.att_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
||||
value_states = value_states.view(hidden_shape)
|
||||
|
||||
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
||||
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
||||
|
||||
if apply_qk_norm and not norm_after_view:
|
||||
query_states = attention.q_norm(query_states)
|
||||
key_states = attention.k_norm(key_states)
|
||||
|
||||
query_states = query_states.view(hidden_shape)
|
||||
key_states = key_states.view(hidden_shape)
|
||||
|
||||
if norm_after_view:
|
||||
query_states = attention.q_norm(query_states)
|
||||
key_states = attention.k_norm(key_states)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
return residual, query_states, key_states, value_states
|
||||
|
||||
def _depth_decode_pre0(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
||||
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
||||
|
||||
def _depth_decode_post_layer(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
block = self.backbone.transformer.blocks[layer_idx]
|
||||
attention = block.self_attn
|
||||
input_shape = residual.shape[:-1]
|
||||
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = attention.attn_out(attn_output)
|
||||
hidden_states = residual + block.dropout(attn_output)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = block.ff_norm(hidden_states)
|
||||
hidden_states = block.mlp(hidden_states)
|
||||
hidden_states = residual + block.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _depth_decode_post_and_pre_next(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
||||
|
||||
def _depth_decode_last_post(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||
return self.backbone.transformer.ln_f(hidden_states)
|
||||
|
||||
def _build_depth_decode_graph(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_length: int,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> _DepthDecodeCudaGraph:
|
||||
text_config = self.backbone.transformer.config
|
||||
device = next_input_ids.device
|
||||
dtype = self.model.lm_head.weight.dtype
|
||||
static = self._depth_decode_spec()
|
||||
num_layers = static.num_hidden_layers
|
||||
head_dim = static.head_dim
|
||||
max_cache_len = int(attention_bias.shape[-1])
|
||||
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
||||
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
|
||||
|
||||
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
||||
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
||||
sin = torch.empty_like(cos)
|
||||
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
||||
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
||||
|
||||
token_ids.copy_(next_input_ids)
|
||||
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
||||
|
||||
pre_graph, pre_output = _capture_cuda_graph(
|
||||
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
||||
device,
|
||||
)
|
||||
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
||||
post_graphs = []
|
||||
for layer_idx in range(num_layers - 1):
|
||||
stage = stages[-1]
|
||||
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||
graph, output = _capture_cuda_graph(
|
||||
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
||||
self._depth_decode_post_and_pre_next(
|
||||
layer_idx,
|
||||
stage.residual,
|
||||
attn_context,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
),
|
||||
device,
|
||||
)
|
||||
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
|
||||
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
||||
|
||||
last_stage = stages[-1]
|
||||
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||
last_graph, last_output = _capture_cuda_graph(
|
||||
lambda: self._depth_decode_last_post(
|
||||
num_layers - 1,
|
||||
last_stage.residual,
|
||||
last_attn_context,
|
||||
),
|
||||
device,
|
||||
)
|
||||
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
|
||||
return _DepthDecodeCudaGraph(
|
||||
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
||||
pre_graph=pre_graph,
|
||||
token_ids=token_ids,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
positions=positions,
|
||||
stages=tuple(stages),
|
||||
post_graphs=tuple(post_graphs),
|
||||
output=last_output,
|
||||
)
|
||||
|
||||
def _get_depth_decode_graph(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_length: int,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> _DepthDecodeCudaGraph:
|
||||
key = self._depth_decode_key(next_input_ids, attention_bias)
|
||||
decode_graph = self.graph
|
||||
if decode_graph is None or decode_graph.cache_key != key:
|
||||
decode_graph = self._build_depth_decode_graph(
|
||||
next_input_ids,
|
||||
past_length=past_length,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
self.graph = decode_graph
|
||||
else:
|
||||
decode_graph.token_ids.copy_(next_input_ids)
|
||||
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
|
||||
return decode_graph
|
||||
|
||||
def _run_depth_decode_attention_core(
|
||||
self,
|
||||
layer_idx: int,
|
||||
stage: _DepthDecodeCudaGraphLayerStage,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_values.update(
|
||||
stage.key,
|
||||
stage.value,
|
||||
layer_idx,
|
||||
cache_kwargs,
|
||||
)
|
||||
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
||||
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
stage.query,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_bias,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
return attn_output.transpose(1, 2)
|
||||
|
||||
def run(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
past_length: int,
|
||||
) -> tuple[torch.Tensor, Cache]:
|
||||
end = past_length + 1
|
||||
decode_graph = self._get_depth_decode_graph(
|
||||
next_input_ids,
|
||||
past_length=past_length,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
cache_position = decode_graph.positions[past_length:end]
|
||||
attention_bias_q = attention_bias[:, :, past_length:end, :end]
|
||||
|
||||
decode_graph.pre_graph.replay()
|
||||
|
||||
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
|
||||
attn_context = self._run_depth_decode_attention_core(
|
||||
layer_idx,
|
||||
decode_graph.stages[layer_idx],
|
||||
past_key_values=past_key_values,
|
||||
attention_bias=attention_bias_q,
|
||||
cache_position=cache_position,
|
||||
cos=decode_graph.cos,
|
||||
sin=decode_graph.sin,
|
||||
)
|
||||
post_graph.attn_context.copy_(attn_context)
|
||||
post_graph.graph.replay()
|
||||
|
||||
return decode_graph.output, past_key_values
|
||||
|
||||
|
||||
def _cuda_graph_tensor_signature(
|
||||
tensor: torch.Tensor | None,
|
||||
) -> tuple[Any, ...] | None:
|
||||
if tensor is None:
|
||||
return None
|
||||
return (
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
str(tensor.dtype),
|
||||
str(tensor.device),
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return (
|
||||
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
|
||||
sig(context.cross_mask),
|
||||
sig(context.self_mask),
|
||||
sig(context.valid_action),
|
||||
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return tuple(
|
||||
(
|
||||
sig(step.conditioning),
|
||||
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
|
||||
tuple(sig(t) for t in step.final_modulation),
|
||||
)
|
||||
for step in modulations
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return (
|
||||
sig(inputs.trajectory),
|
||||
_cuda_graph_context_signature(inputs.context),
|
||||
_cuda_graph_modulation_signature(inputs.modulations),
|
||||
sig(inputs.action_dim_is_pad),
|
||||
int(steps),
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if tensor is None:
|
||||
return None
|
||||
static = torch.empty_strided(
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
device=tensor.device,
|
||||
dtype=tensor.dtype,
|
||||
)
|
||||
static.copy_(tensor)
|
||||
return static
|
||||
|
||||
|
||||
def _clone_static_context(context: Any) -> Any:
|
||||
rope_cache = None
|
||||
if context.rope_cache is not None:
|
||||
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
|
||||
return context.__class__(
|
||||
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
|
||||
cross_mask=_clone_static_tensor(context.cross_mask),
|
||||
self_mask=_clone_static_tensor(context.self_mask),
|
||||
valid_action=_clone_static_tensor(context.valid_action),
|
||||
rope_cache=rope_cache,
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
|
||||
return tuple(
|
||||
step.__class__(
|
||||
conditioning=_clone_static_tensor(step.conditioning),
|
||||
block_modulations=tuple(
|
||||
tuple(_clone_static_tensor(t) for t in block_modulation)
|
||||
for block_modulation in step.block_modulations
|
||||
),
|
||||
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
|
||||
)
|
||||
for step in modulations
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
||||
return _ActionFlowInputs(
|
||||
trajectory=_clone_static_tensor(inputs.trajectory),
|
||||
context=_clone_static_context(inputs.context),
|
||||
modulations=_clone_static_modulations(inputs.modulations),
|
||||
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
|
||||
)
|
||||
|
||||
|
||||
def _copy_context_(dst: Any, src: Any) -> None:
|
||||
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
||||
dst_k.copy_(src_k)
|
||||
dst_v.copy_(src_v)
|
||||
if src.cross_mask is not None:
|
||||
dst.cross_mask.copy_(src.cross_mask)
|
||||
if src.self_mask is not None:
|
||||
dst.self_mask.copy_(src.self_mask)
|
||||
if src.valid_action is not None:
|
||||
dst.valid_action.copy_(src.valid_action)
|
||||
if src.rope_cache is not None:
|
||||
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
||||
dst_tensor.copy_(src_tensor)
|
||||
|
||||
|
||||
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
|
||||
dst.trajectory.copy_(src.trajectory)
|
||||
_copy_context_(dst.context, src.context)
|
||||
if src.action_dim_is_pad is not None:
|
||||
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
||||
|
||||
|
||||
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
unsqueeze_dim: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _capture_cuda_graph(
|
||||
fn,
|
||||
device: torch.device,
|
||||
*,
|
||||
after_warmup=None,
|
||||
) -> tuple[torch.cuda.CUDAGraph, Any]:
|
||||
warmup_stream = torch.cuda.Stream(device=device)
|
||||
warmup_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(warmup_stream):
|
||||
fn()
|
||||
torch.cuda.current_stream(device).wait_stream(warmup_stream)
|
||||
if after_warmup is not None:
|
||||
after_warmup()
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
output = fn()
|
||||
return graph, output
|
||||
4591
src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py
Normal file
4591
src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
431
src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py
Normal file
431
src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py
Normal file
@@ -0,0 +1,431 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
Processor class for MolmoAct2.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import dataclasses
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.video_utils import VideoInput
|
||||
from transformers.processing_utils import (
|
||||
Unpack,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
||||
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
||||
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
||||
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
||||
IM_START_TOKEN = f"<im_start>"
|
||||
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
||||
FRAME_START_TOKEN = f"<frame_start>"
|
||||
IM_END_TOKEN = f"<im_end>"
|
||||
FRAME_END_TOKEN = f"<frame_end>"
|
||||
IM_COL_TOKEN = f"<im_col>"
|
||||
IMAGE_PROMPT = "<|image|>"
|
||||
VIDEO_PROMPT = "<|video|>"
|
||||
|
||||
IMAGE_TOKENS = [
|
||||
IMAGE_PATCH_TOKEN,
|
||||
IM_COL_TOKEN,
|
||||
IM_START_TOKEN,
|
||||
LOW_RES_IMAGE_START_TOKEN,
|
||||
FRAME_START_TOKEN,
|
||||
IM_END_TOKEN,
|
||||
FRAME_END_TOKEN,
|
||||
IMAGE_LOW_RES_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
"""MolmoAct2 processor kwargs"""
|
||||
|
||||
images_kwargs: MolmoAct2ImagesKwargs
|
||||
videos_kwargs: MolmoAct2VideoProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"return_mm_token_type_ids": True,
|
||||
},
|
||||
"videos_kwargs": {"return_metadata": True},
|
||||
}
|
||||
|
||||
|
||||
class MolmoAct2Processor(ProcessorMixin):
|
||||
attributes = ["image_processor", "video_processor", "tokenizer"]
|
||||
optional_attributes = [
|
||||
"chat_template",
|
||||
"time_mode",
|
||||
"image_use_col_tokens",
|
||||
"use_single_crop_col_tokens",
|
||||
"use_single_crop_start_token",
|
||||
"video_use_col_tokens",
|
||||
"use_frame_special_tokens",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor: MolmoAct2ImageProcessor = None,
|
||||
video_processor: MolmoAct2VideoProcessor = None,
|
||||
tokenizer: AutoTokenizer = None,
|
||||
chat_template: str | None = None,
|
||||
image_use_col_tokens: bool | None = True,
|
||||
use_single_crop_col_tokens: bool | None = None,
|
||||
use_single_crop_start_token: bool | None = True,
|
||||
video_use_col_tokens: bool | None = False,
|
||||
use_frame_special_tokens: bool | None = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
image_processor,
|
||||
video_processor,
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
self.image_use_col_tokens = image_use_col_tokens
|
||||
self.use_single_crop_col_tokens = use_single_crop_col_tokens
|
||||
self.use_single_crop_start_token = use_single_crop_start_token
|
||||
self.video_use_col_tokens = video_use_col_tokens
|
||||
self.use_frame_special_tokens = use_frame_special_tokens
|
||||
|
||||
self.image_placeholder_token = IMAGE_PROMPT
|
||||
self.video_placeholder_token = VIDEO_PROMPT
|
||||
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
|
||||
|
||||
def get_image_tokens(self, image_grid: np.ndarray):
|
||||
resized_h, resized_w, height, width = image_grid
|
||||
if int(height) == 0 or int(width) == 0:
|
||||
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||
use_single_crop_col_tokens = (
|
||||
self.image_use_col_tokens
|
||||
if self.use_single_crop_col_tokens is None
|
||||
else self.use_single_crop_col_tokens
|
||||
)
|
||||
if use_single_crop_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[IM_START_TOKEN],
|
||||
np.tile(per_row, [resized_h]),
|
||||
[IM_END_TOKEN],
|
||||
]
|
||||
return np.concatenate(joint)
|
||||
per_row = np.full(width, IMAGE_PATCH_TOKEN)
|
||||
if self.image_use_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[IM_START_TOKEN],
|
||||
np.tile(per_row, [height]),
|
||||
[IM_END_TOKEN],
|
||||
]
|
||||
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||
use_single_crop_col_tokens = (
|
||||
self.image_use_col_tokens
|
||||
if self.use_single_crop_col_tokens is None
|
||||
else self.use_single_crop_col_tokens
|
||||
)
|
||||
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
|
||||
if use_single_crop_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[image_start_token],
|
||||
np.tile(per_row, [resized_h]),
|
||||
[IM_END_TOKEN],
|
||||
] + joint
|
||||
|
||||
return np.concatenate(joint)
|
||||
|
||||
def get_video_string(
|
||||
self,
|
||||
video_grid: np.ndarray,
|
||||
timestamps: np.ndarray,
|
||||
):
|
||||
if self.use_frame_special_tokens:
|
||||
start_token_id = FRAME_START_TOKEN
|
||||
end_token_id = FRAME_END_TOKEN
|
||||
else:
|
||||
start_token_id = IM_START_TOKEN
|
||||
end_token_id = IM_END_TOKEN
|
||||
|
||||
num_frames, h, w = video_grid
|
||||
video_string: str = ""
|
||||
for frame_idx, frame_time in enumerate(timestamps):
|
||||
# `per-frame-compact` time mode
|
||||
prev_space = " " if frame_idx > 0 else ""
|
||||
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
|
||||
|
||||
video_string += frame_prefix
|
||||
per_row = np.full(w, IMAGE_PATCH_TOKEN)
|
||||
if self.video_use_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
extra_tokens = np.tile(per_row, [h])
|
||||
video_tokens = [
|
||||
[start_token_id],
|
||||
extra_tokens,
|
||||
[end_token_id],
|
||||
]
|
||||
video_string += "".join(np.concatenate(video_tokens, 0))
|
||||
|
||||
return video_string
|
||||
|
||||
def insert_bos(
|
||||
self,
|
||||
input_ids: np.ndarray,
|
||||
attention_mask: np.ndarray,
|
||||
bos_token_id: int,
|
||||
pad_token_id: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_ids: [B, S] array with left padding
|
||||
attention_mask: [B, S] array (0 for pad, 1 for valid)
|
||||
bos_token_id: int
|
||||
pad_token_id: int
|
||||
Returns:
|
||||
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
|
||||
attention_mask_out: same shape as input_ids_out
|
||||
"""
|
||||
|
||||
need_to_expand = len(input_ids.shape) == 1
|
||||
if need_to_expand:
|
||||
input_ids = input_ids[None, :]
|
||||
attention_mask = attention_mask[None, :]
|
||||
|
||||
B, S = input_ids.shape
|
||||
|
||||
# Handle zero-length sequence
|
||||
if S == 0:
|
||||
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
|
||||
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
|
||||
if need_to_expand:
|
||||
new_input_ids = new_input_ids[0]
|
||||
new_attention_mask = new_attention_mask[0]
|
||||
return new_input_ids, new_attention_mask
|
||||
|
||||
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
|
||||
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
|
||||
|
||||
if bos_already_present:
|
||||
if need_to_expand:
|
||||
input_ids = input_ids[0]
|
||||
attention_mask = attention_mask[0]
|
||||
return input_ids, attention_mask
|
||||
else:
|
||||
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
|
||||
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
|
||||
|
||||
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
|
||||
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
|
||||
tgt_idx = src_idx + 1 # shit right
|
||||
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
|
||||
|
||||
# flatten valid_positions
|
||||
flat_vals = input_ids[valid_mask]
|
||||
flat_batch = batch_idx[valid_mask]
|
||||
flat_tgt = tgt_idx[valid_mask]
|
||||
|
||||
new_input_ids[flat_batch, flat_tgt] = flat_vals
|
||||
new_attention_mask[flat_batch, flat_tgt] = 1
|
||||
|
||||
insert_pos = first_valid_index
|
||||
new_input_ids[np.arange(B), insert_pos] = bos_token_id
|
||||
new_attention_mask[np.arange(B), insert_pos] = 1
|
||||
|
||||
if need_to_expand:
|
||||
new_input_ids = new_input_ids[0]
|
||||
new_attention_mask = new_attention_mask[0]
|
||||
|
||||
return new_input_ids, new_attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
images: ImageInput = None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
|
||||
Args:
|
||||
text (`str`, `list[str]`, `list[list[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
|
||||
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
|
||||
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
|
||||
- `"timestamps"`: `np.ndarray` of shape (T,)
|
||||
- `"sampled_fps"`: `float` (optional)
|
||||
- `"sampling_augmentation"`: `str` (optional)
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
`BatchFeature`: A [`BatchFeature`] with the following fields:
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
|
||||
Returned when `images` is not `None`.
|
||||
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
|
||||
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
|
||||
Returned when `videos` is not `None`.
|
||||
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
MolmoAct2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
image_grids = image_inputs["image_grids"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grids = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_grids = videos_inputs["video_grids"]
|
||||
# If user has not requested video metadata, pop it
|
||||
if "return_metadata" not in kwargs:
|
||||
video_metadata = videos_inputs.pop("video_metadata")
|
||||
else:
|
||||
video_metadata = videos_inputs["video_metadata"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grids = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = text.copy() # below lines change text in-place
|
||||
|
||||
if image_grids is not None:
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
num_images = text[i].count(self.image_placeholder_token)
|
||||
image_grids_i = image_grids[index : index + num_images]
|
||||
for image_grid in image_grids_i:
|
||||
image_tokens = self.get_image_tokens(image_grid)
|
||||
image_string = "".join(image_tokens)
|
||||
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
|
||||
index += num_images
|
||||
|
||||
if video_grids is not None:
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
num_videos = text[i].count(self.video_placeholder_token)
|
||||
assert num_videos in {0, 1}, "At most one video is supported for now"
|
||||
video_grids_i = video_grids[index : index + num_videos]
|
||||
metadata_i = video_metadata[index : index + num_videos]
|
||||
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
||||
video_string = self.get_video_string(
|
||||
video_grid,
|
||||
metadata.timestamps,
|
||||
)
|
||||
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
|
||||
index += num_videos
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
input_ids = text_inputs["input_ids"]
|
||||
attention_mask = text_inputs["attention_mask"]
|
||||
|
||||
input_ids = np.array(input_ids)
|
||||
attention_mask = np.array(attention_mask)
|
||||
|
||||
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
||||
input_ids, attention_mask = self.insert_bos(
|
||||
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
if return_mm_token_type_ids:
|
||||
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
|
||||
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
|
||||
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
||||
|
||||
text_inputs["input_ids"] = input_ids.tolist()
|
||||
text_inputs["attention_mask"] = attention_mask.tolist()
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **image_inputs, **videos_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`list[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
MolmoAct2Processor.register_for_auto_class()
|
||||
@@ -0,0 +1,997 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Video processor class for MolmoAct2"""
|
||||
|
||||
from functools import partial
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import redirect_stdout
|
||||
from io import BytesIO
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.video_utils import (
|
||||
VideoInput,
|
||||
is_valid_video,
|
||||
make_batched_videos,
|
||||
make_batched_metadata,
|
||||
VideoMetadata,
|
||||
)
|
||||
from transformers.processing_utils import Unpack, VideosKwargs
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import (
|
||||
is_av_available,
|
||||
is_decord_available,
|
||||
is_torchcodec_available,
|
||||
is_yt_dlp_available,
|
||||
TensorType,
|
||||
logging,
|
||||
to_numpy,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_VIDEO_FPS = 8
|
||||
|
||||
|
||||
def normalize_image(
|
||||
image: np.ndarray,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
) -> np.ndarray:
|
||||
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
desired_output_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
if len(image.shape) == 3:
|
||||
is_video = False
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
else:
|
||||
is_video = True
|
||||
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
|
||||
dtype = image.dtype
|
||||
if torch.is_floating_point(image):
|
||||
in_min = 0.0
|
||||
in_max = 1.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||
|
||||
resized = resized.to(torch.float32)
|
||||
resized = (resized - in_min) / (in_max - in_min)
|
||||
|
||||
if is_video:
|
||||
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
|
||||
else:
|
||||
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def build_resized_image(
|
||||
image: np.ndarray,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
resized = resize_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
)
|
||||
resized = normalize_image(resized, image_mean, image_std)
|
||||
if len(resized.shape) == 3:
|
||||
resized = np.expand_dims(resized, 0)
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||
return resized, resize_idx
|
||||
|
||||
|
||||
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||
if len(array.shape) == 3:
|
||||
n_crops, h, w = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||
return array
|
||||
else:
|
||||
n_crops, h, w, c = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||
return array
|
||||
|
||||
|
||||
def arange_for_pooling(
|
||||
idx_arr: np.ndarray,
|
||||
pool_h: int,
|
||||
pool_w: int,
|
||||
) -> np.ndarray:
|
||||
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||
idx_arr = np.pad(
|
||||
idx_arr,
|
||||
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||
mode="constant",
|
||||
constant_values=-1,
|
||||
)
|
||||
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||
|
||||
|
||||
def image_to_patches_and_grids(
|
||||
image: ImageInput,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
image_pooling_w: int,
|
||||
image_pooling_h: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:return image_grids, the shape of each image after pooling
|
||||
:return crops, the image crops to processes with the ViT
|
||||
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||
patches in `crops` to pool for that token, masked with -1
|
||||
"""
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
pooling_w = image_pooling_w
|
||||
pooling_h = image_pooling_h
|
||||
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
h, w = pooling_idx.shape[:2]
|
||||
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||
image_grid = [h, w]
|
||||
return (
|
||||
image_grid,
|
||||
batch_pixels_to_patches(resized, image_patch_size),
|
||||
pooling_idx,
|
||||
)
|
||||
|
||||
|
||||
def get_candidate_target_fps(
|
||||
video_fps: int | float,
|
||||
sampling_fps: int | float,
|
||||
max_fps: int | float = MAX_VIDEO_FPS,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
|
||||
|
||||
Examples:
|
||||
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
|
||||
[2, 6]
|
||||
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
|
||||
[1, 5]
|
||||
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
|
||||
[2]
|
||||
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
|
||||
"""
|
||||
video_fps = int(video_fps)
|
||||
sampling_fps = int(sampling_fps)
|
||||
max_fps = int(max_fps)
|
||||
|
||||
if sampling_fps is None:
|
||||
raise ValueError("sampling_fps must be provided")
|
||||
if video_fps <= 0 or sampling_fps <= 0:
|
||||
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
|
||||
if video_fps % sampling_fps != 0:
|
||||
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
|
||||
|
||||
candidates = []
|
||||
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
|
||||
if candidate > max_fps:
|
||||
break
|
||||
if video_fps % candidate == 0:
|
||||
candidates.append(float(candidate))
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def read_video_decord(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import from decord
|
||||
import importlib
|
||||
|
||||
decord = importlib.import_module("decord")
|
||||
|
||||
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
|
||||
duration = time_stamps[-1][1] - time_stamps[0][0]
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="decord",
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
offset = time_stamps[0, 0]
|
||||
|
||||
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
|
||||
ix = np.minimum(ix, len(time_stamps) - 1)
|
||||
|
||||
video = vr.get_batch(ix).asnumpy()
|
||||
metadata.update(
|
||||
{
|
||||
"frames_indices": target_timestamps * video_fps,
|
||||
"height": video.shape[1],
|
||||
"width": video.shape[2],
|
||||
}
|
||||
)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchcodec(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using torchcodec decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import torchcodec
|
||||
import importlib
|
||||
|
||||
torchcodec = importlib.import_module("torchcodec")
|
||||
|
||||
decoder = torchcodec.decoders.VideoDecoder(
|
||||
video_path,
|
||||
# Interestingly `exact` mode takes less than approximate when we load the whole video
|
||||
seek_mode="exact",
|
||||
# Allow FFmpeg decide on the number of threads for efficiency
|
||||
num_ffmpeg_threads=0,
|
||||
)
|
||||
# If the first frame starts at > 0, we effectively clip the video starting at that time
|
||||
# since (most) video players would also skip to that time
|
||||
time_offset = decoder.metadata.begin_stream_seconds_from_content
|
||||
# Note this duration does assume we started playing at `time_offset`
|
||||
duration = decoder.metadata.duration_seconds
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=decoder.metadata.num_frames,
|
||||
fps=decoder.metadata.average_fps,
|
||||
duration=duration,
|
||||
video_backend="torchcodec",
|
||||
height=decoder.metadata.height,
|
||||
width=decoder.metadata.width,
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
|
||||
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
|
||||
# out-of-bounds, to handle this we sanity check then clip them
|
||||
assert all(x >= 0 for x in target_timestamps)
|
||||
assert all(x < duration + 1e-6 for x in target_timestamps)
|
||||
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
|
||||
# exact boundary value, we should still get the first/last frame anyway
|
||||
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
|
||||
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
|
||||
# Note we avoid using numpy ops here to reduce floating precision issues
|
||||
timestamps = [x + time_offset for x in target_timestamps]
|
||||
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
|
||||
|
||||
video = (
|
||||
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
|
||||
) # Convert to THWC format
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
metadata.frames_indices = target_timestamps * metadata.fps
|
||||
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_pyav(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using the PyAV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import torchcodec
|
||||
import importlib
|
||||
|
||||
av = importlib.import_module("av")
|
||||
|
||||
with av.open(video_path) as container:
|
||||
video_stream = container.streams.video[0]
|
||||
fps = video_stream.average_rate or video_stream.guessed_rate
|
||||
it = container.decode(video=0)
|
||||
frames = list(it)
|
||||
|
||||
stream = container.streams.video[0]
|
||||
start = frames[0].pts * stream.time_base
|
||||
container_end = stream.duration
|
||||
if container_end is not None:
|
||||
container_end *= stream.time_base
|
||||
if container_end is None or container_end < frames[-1].pts:
|
||||
# Some problem with stream duration, so use the frame PTS directly
|
||||
# and guess the duration of the last frame
|
||||
end = frames[-1].pts * stream.time_base + 1 / fps
|
||||
else:
|
||||
end = container_end
|
||||
duration = float(end - start)
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=len(frames),
|
||||
fps=float(fps),
|
||||
duration=float(duration),
|
||||
video_backend="pyav",
|
||||
height=video_stream.height,
|
||||
width=video_stream.width,
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
offset = float(start)
|
||||
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
|
||||
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
|
||||
indices = np.minimum(indices, len(end_time_stamps) - 1)
|
||||
|
||||
video = np.stack(
|
||||
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
metadata.frames_indices = target_timestamps * fps
|
||||
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
"decord": read_video_decord,
|
||||
"torchcodec": read_video_torchcodec,
|
||||
"pyav": read_video_pyav,
|
||||
}
|
||||
|
||||
|
||||
def load_video(
|
||||
video: VideoInput,
|
||||
backend: str = "decord",
|
||||
sample_timestamps_fn: Callable | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
video (`VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
backend (`str`, *optional*, defaults to `"decord"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
"""
|
||||
|
||||
# Early exit if provided an array or `PIL` frames
|
||||
if not isinstance(video, str):
|
||||
metadata = [None] * len(video)
|
||||
return video, metadata
|
||||
|
||||
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
# Lazy import from yt_dlp
|
||||
import importlib
|
||||
|
||||
yt_dlp = importlib.import_module("yt_dlp")
|
||||
|
||||
buffer = BytesIO()
|
||||
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
|
||||
f.download([video])
|
||||
bytes_obj = buffer.getvalue()
|
||||
file_obj = BytesIO(bytes_obj)
|
||||
elif video.startswith("http://") or video.startswith("https://"):
|
||||
file_obj = BytesIO(requests.get(video, timeout=10).content)
|
||||
elif os.path.isfile(video):
|
||||
file_obj = video
|
||||
else:
|
||||
raise TypeError(
|
||||
"Incorrect format used for video. Should be an url linking to an video or a local path."
|
||||
)
|
||||
|
||||
# can also load with decord, but not cv2/torchvision
|
||||
# both will fail in case of url links
|
||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||
if video_is_url and backend == "opencv":
|
||||
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
|
||||
|
||||
if (
|
||||
(not is_decord_available() and backend == "decord")
|
||||
or (not is_torchcodec_available() and backend == "torchcodec")
|
||||
or (not is_av_available() and backend == "pyav")
|
||||
):
|
||||
raise ImportError(
|
||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||
f"Make sure to install {backend} before loading the video."
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def get_target_fps(
|
||||
video_fps: float,
|
||||
max_frames: int,
|
||||
total_frames: int,
|
||||
frame_sample_mode: str,
|
||||
candidate_target_fps: tuple[float],
|
||||
) -> float:
|
||||
"""
|
||||
Get the target fps that best spans the video and has the most frames sampled
|
||||
"""
|
||||
num_frames_sampled = 0
|
||||
selected_target_fps = None
|
||||
for target_fps in candidate_target_fps:
|
||||
step_size = max(int(video_fps / target_fps), 1)
|
||||
num_frames_sampled_at_fps = int(total_frames / step_size)
|
||||
if num_frames_sampled == 0:
|
||||
if "uniform" in frame_sample_mode:
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
break
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
|
||||
else:
|
||||
# the candidate sampling fps increases so frame count can't decrease
|
||||
assert num_frames_sampled <= num_frames_sampled_at_fps
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
# choose the sampling fps that spans the video
|
||||
continue
|
||||
|
||||
elif num_frames_sampled_at_fps > num_frames_sampled:
|
||||
# both are less than max_frames, choose the one with higher density of frames sampled
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
return selected_target_fps
|
||||
|
||||
|
||||
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
|
||||
if selected_target_fps is None:
|
||||
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
|
||||
else:
|
||||
step_size = max(int(video_fps / selected_target_fps), 1)
|
||||
frame_indices = np.arange(0, total_frames, step_size)
|
||||
if len(frame_indices) > max_frames:
|
||||
frame_indices = frame_indices[:max_frames]
|
||||
return selected_target_fps, frame_indices
|
||||
|
||||
|
||||
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
|
||||
patch_size: int | None
|
||||
pooling_size: list[int] | None
|
||||
frame_sample_mode: str | None
|
||||
max_fps: int | None
|
||||
sampling_fps: int | None
|
||||
|
||||
|
||||
class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
size = {"height": 378, "width": 378}
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
patch_size = 14
|
||||
pooling_size = [3, 3]
|
||||
do_sample_frames = True
|
||||
frame_sample_mode = "uniform_last_frame"
|
||||
max_fps = 2
|
||||
sampling_fps = 2
|
||||
valid_kwargs = MolmoAct2VideoProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
if self.size is not None and (
|
||||
self.size.get("height", None) is None or self.size.get("width", None) is None
|
||||
):
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
def _further_process_kwargs(
|
||||
self,
|
||||
size: SizeDict | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Update kwargs that need further processing before being validated
|
||||
Can be overridden by subclasses to customize the processing of kwargs.
|
||||
"""
|
||||
if size is not None and ("height" not in size or "width" not in size):
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
return super()._further_process_kwargs(size=size, **kwargs)
|
||||
|
||||
def sample_times(
|
||||
self,
|
||||
metadata: VideoMetadata,
|
||||
frame_sample_mode: str,
|
||||
num_frames: int,
|
||||
max_fps: int | None = None,
|
||||
sampling_fps: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Time-based sampling if an array video is passed
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
Metadata of the video containing information about total duration, fps and total number of frames.
|
||||
frame_sample_mode (`str`, *optional*):
|
||||
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
||||
num_frames (`int`, *optional*):
|
||||
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
||||
man_fps (`int`, *optional*):
|
||||
Maximum frames per second to sample.
|
||||
sampling_fps (`int`, *optional*):
|
||||
Sampling frames per second. Defaults to `self.sampling_fps`.
|
||||
Used when `frame_sample_mode` is `"fps"`.
|
||||
"""
|
||||
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
||||
num_frames = num_frames or self.num_frames
|
||||
sampling_fps = sampling_fps or self.sampling_fps
|
||||
|
||||
duration = metadata.duration or metadata.total_num_frames / metadata.fps
|
||||
if frame_sample_mode == "fps":
|
||||
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
||||
# Try larger and larger FPSs until we hit one that can't span the video
|
||||
target_fps = candidate_target_fps[0]
|
||||
for candidate_fps in candidate_target_fps[1:]:
|
||||
if num_frames / candidate_fps < duration:
|
||||
break
|
||||
target_fps = candidate_fps
|
||||
times = np.arange(0, num_frames) / target_fps
|
||||
times = times[times < duration]
|
||||
return times
|
||||
elif frame_sample_mode == "uniform_last_frame":
|
||||
if max_fps is not None:
|
||||
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
|
||||
if max_duration < duration:
|
||||
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
|
||||
else:
|
||||
times = np.arange(0.0, stop=duration, step=1 / max_fps)
|
||||
times = np.concatenate([times, [duration]], axis=0)
|
||||
assert len(times) <= num_frames
|
||||
else:
|
||||
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
|
||||
return times
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
def sample_frames(
|
||||
self,
|
||||
metadata: VideoMetadata,
|
||||
frame_sample_mode: str | None = None,
|
||||
num_frames: int | None = None,
|
||||
max_fps: int | None = None,
|
||||
sampling_fps: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Frame-based sampling if an array video is passed
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
Metadata of the video containing information about total duration, fps and total number of frames.
|
||||
frame_sample_mode (`str`, *optional*):
|
||||
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
||||
num_frames (`int`, *optional*):
|
||||
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
||||
max_fps (`int`, *optional*):
|
||||
Maximum frames per second to sample.
|
||||
sampling_fps (`int`, *optional*):
|
||||
Sampling frames per second. Defaults to `self.sampling_fps`.
|
||||
Used when `frame_sample_mode` is `"fps"`.
|
||||
"""
|
||||
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
||||
num_frames = num_frames or self.num_frames
|
||||
sampling_fps = sampling_fps or self.sampling_fps
|
||||
|
||||
total_num_frames = metadata.total_num_frames
|
||||
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
|
||||
duration = total_num_frames / metadata.fps
|
||||
if total_num_frames <= 2:
|
||||
return np.arange(total_num_frames).astype(int)
|
||||
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
|
||||
# uniform fallback
|
||||
indices = np.linspace(
|
||||
0,
|
||||
total_num_frames - 1,
|
||||
num=min(num_frames, total_num_frames),
|
||||
endpoint=True,
|
||||
).astype(int)
|
||||
return indices
|
||||
else:
|
||||
float_indices = np.arange(
|
||||
0.0,
|
||||
stop=total_num_frames - 1,
|
||||
step=float(metadata.fps / max_fps),
|
||||
)
|
||||
if np.round(float_indices[-1]) != total_num_frames - 1:
|
||||
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
|
||||
indices = np.round(float_indices).astype(int)
|
||||
assert indices[-1] < total_num_frames
|
||||
assert len(float_indices) <= num_frames
|
||||
return indices
|
||||
elif frame_sample_mode == "uniform_last_frame":
|
||||
indices = np.linspace(
|
||||
0,
|
||||
total_num_frames - 1,
|
||||
num=min(num_frames, total_num_frames),
|
||||
endpoint=True,
|
||||
).astype(int)
|
||||
return indices
|
||||
elif frame_sample_mode == "fps":
|
||||
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
||||
selected_target_fps = get_target_fps(
|
||||
metadata.fps,
|
||||
num_frames,
|
||||
total_num_frames,
|
||||
frame_sample_mode,
|
||||
candidate_target_fps,
|
||||
)
|
||||
_, indices = get_frame_times_and_chosen_fps(
|
||||
selected_target_fps,
|
||||
total_num_frames,
|
||||
num_frames,
|
||||
metadata.fps,
|
||||
)
|
||||
return indices
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
|
||||
"""
|
||||
Convert a single or a list of urls into the corresponding `np.array` objects.
|
||||
|
||||
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
||||
returned.
|
||||
"""
|
||||
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
|
||||
raise ImportError(
|
||||
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
|
||||
)
|
||||
|
||||
if is_decord_available():
|
||||
backend = "decord"
|
||||
elif is_torchcodec_available():
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `torchcodec`."
|
||||
)
|
||||
backend = "torchcodec"
|
||||
else:
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `PyAV`."
|
||||
)
|
||||
backend = "pyav"
|
||||
|
||||
if isinstance(video_url_or_urls, list):
|
||||
return list(
|
||||
zip(
|
||||
*[
|
||||
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
|
||||
for x in video_url_or_urls
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
|
||||
|
||||
def _decode_and_sample_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
video_metadata: VideoMetadata | dict,
|
||||
do_sample_frames: bool | None = None,
|
||||
sample_indices_fn: Callable | None = None,
|
||||
sample_timestamps_fn: Callable | None = None,
|
||||
):
|
||||
"""
|
||||
Decode input videos and sample frames if needed.
|
||||
"""
|
||||
videos = make_batched_videos(videos)
|
||||
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
|
||||
|
||||
# Framed-based sampling if an array video is passed
|
||||
# Otherwise, time-based sampling with decoding
|
||||
if is_valid_video(videos[0]) and do_sample_frames:
|
||||
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
|
||||
sampled_videos = []
|
||||
sampled_metadata = []
|
||||
for video, metadata in zip(videos, video_metadata):
|
||||
indices = sample_indices_fn(metadata=metadata)
|
||||
metadata.frames_indices = indices
|
||||
sampled_videos.append(video[indices])
|
||||
sampled_metadata.append(metadata)
|
||||
videos = sampled_videos
|
||||
video_metadata = sampled_metadata
|
||||
elif not is_valid_video(videos[0]):
|
||||
if sample_indices_fn is None:
|
||||
logger.warning(
|
||||
"do_sample_frames is False, but video array is not provided: "
|
||||
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
|
||||
)
|
||||
if isinstance(videos[0], list):
|
||||
raise ValueError("A list of images is not supported for video input!")
|
||||
else:
|
||||
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
|
||||
|
||||
return videos, video_metadata
|
||||
|
||||
def _prepare_input_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs,
|
||||
) -> list[np.ndarray]:
|
||||
processed_videos = [to_numpy(video) for video in videos]
|
||||
return processed_videos
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
|
||||
)
|
||||
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
do_sample_frames = kwargs.pop("do_sample_frames")
|
||||
video_metadata = kwargs.pop("video_metadata")
|
||||
|
||||
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
|
||||
sample_timestamps_fn = partial(self.sample_times, **kwargs)
|
||||
videos, video_metadata = self._decode_and_sample_videos(
|
||||
videos,
|
||||
video_metadata=video_metadata,
|
||||
do_sample_frames=do_sample_frames,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
sample_timestamps_fn=sample_timestamps_fn,
|
||||
)
|
||||
videos = self._prepare_input_videos(videos=videos)
|
||||
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
return_metadata = kwargs.pop("return_metadata")
|
||||
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
|
||||
if return_metadata:
|
||||
preprocessed_videos["video_metadata"] = video_metadata
|
||||
return preprocessed_videos
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: list[np.ndarray],
|
||||
size: SizeDict | None = None,
|
||||
resample: PILImageResampling | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_size: list[int] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess a video for the model.
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video to preprocess.
|
||||
size (`SizeDict`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||
The pooling size of the vision adapter.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
A `BatchFeature` containing the following keys:
|
||||
- `pixel_values_videos`: The preprocessed videos.
|
||||
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
|
||||
- `video_grids`: The video grids.
|
||||
"""
|
||||
if size.height is None or size.width is None:
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
base_image_input_size = [size.height, size.width]
|
||||
|
||||
resample = resample or self.resample
|
||||
image_mean = image_mean or self.image_mean
|
||||
image_std = image_std or self.image_std
|
||||
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||
|
||||
patch_size = patch_size or self.patch_size
|
||||
pooling_size = pooling_size or self.pooling_size
|
||||
|
||||
image_pooling_h, image_pooling_w = pooling_size
|
||||
|
||||
batch_grids = []
|
||||
batch_crops = []
|
||||
batch_pooled_patches_idx = []
|
||||
|
||||
for video in videos:
|
||||
all_crops = []
|
||||
pooled_patches_idx = []
|
||||
|
||||
for frame in video:
|
||||
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||
frame,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
patch_size,
|
||||
image_pooling_w,
|
||||
image_pooling_h,
|
||||
)
|
||||
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
|
||||
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
|
||||
pooled_patches_idx.append(pooled_idx_with_offset)
|
||||
all_crops.append(crops)
|
||||
|
||||
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
|
||||
all_crops = np.concatenate(all_crops, 0)
|
||||
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
|
||||
|
||||
batch_grids.append(video_grid)
|
||||
batch_crops.append(all_crops)
|
||||
batch_pooled_patches_idx.append(pooled_patches_idx)
|
||||
|
||||
video_grids = np.stack(batch_grids, 0)
|
||||
pixel_values_videos = np.concatenate(batch_crops, 0)
|
||||
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
|
||||
data = dict(
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
video_token_pooling=video_token_pooling,
|
||||
video_grids=video_grids,
|
||||
)
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
MolmoAct2VideoProcessor.register_for_auto_class()
|
||||
1551
src/lerobot/policies/molmoact2/modeling_molmoact2.py
Normal file
1551
src/lerobot/policies/molmoact2/modeling_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
1083
src/lerobot/policies/molmoact2/processor_molmoact2.py
Normal file
1083
src/lerobot/policies/molmoact2/processor_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_vla_jepa_README.md
|
||||
23
src/lerobot/policies/vla_jepa/__init__.py
Normal file
23
src/lerobot/policies/vla_jepa/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
337
src/lerobot/policies/vla_jepa/action_head.py
Normal file
337
src/lerobot/policies/vla_jepa/action_head.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
154
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
154
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
gripper_dim: int = 6
|
||||
gripper_threshold: float = 0.5
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
629
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
629
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
@@ -0,0 +1,629 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = config.jepa_tubelet_size
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
155
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
155
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes a gripper dimension after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
|
||||
Only applied when action has more dimensions than gripper_dim.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(
|
||||
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(
|
||||
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
117
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
117
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
418
src/lerobot/policies/vla_jepa/world_model.py
Normal file
418
src/lerobot/policies/vla_jepa/world_model.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -21,11 +21,13 @@ from .factory import (
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
"TOPRewardConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
@@ -37,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm".
|
||||
"sarm", "topreward".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
@@ -53,6 +54,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "topreward":
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -69,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm".
|
||||
"reward_classifier", "sarm", "topreward".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -82,6 +87,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -162,6 +169,14 @@ def make_reward_pre_post_processors(
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, TOPRewardConfig):
|
||||
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
return make_topreward_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
19
src/lerobot/rewards/topreward/__init__.py
Normal file
19
src/lerobot/rewards/topreward/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_topreward import TOPRewardConfig
|
||||
from .modeling_topreward import TOPRewardModel
|
||||
from .processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
__all__ = ["TOPRewardConfig", "TOPRewardModel", "make_topreward_pre_post_processors"]
|
||||
353
src/lerobot/rewards/topreward/compute_rabc_weights.py
Normal file
353
src/lerobot/rewards/topreward/compute_rabc_weights.py
Normal file
@@ -0,0 +1,353 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame TOPReward progress curves for a LeRobot dataset.
|
||||
|
||||
For each episode, scores trajectory prefixes of increasing length using
|
||||
the TOPReward reward model, min-max normalises the raw log-prob rewards per episode,
|
||||
and writes a parquet file with one row per frame.
|
||||
|
||||
The parquet uses the same schema as SARM's :mod:`lerobot.rewards.sarm.compute_rabc_weights`.
|
||||
|
||||
Usage:
|
||||
# Sparse-dense mode (15 anchors per episode, matches upstream)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-samples 15
|
||||
|
||||
# Use a different VLM backbone
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPRewardEncoderProcessorStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "topreward_progress.parquet"
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def normalize_rewards(rewards: list[float] | np.ndarray) -> np.ndarray:
|
||||
"""Min-max normalise raw log-prob rewards into ``[0, 1]``."""
|
||||
rewards_arr = np.asarray(rewards, dtype=np.float64)
|
||||
if rewards_arr.size == 0:
|
||||
return rewards_arr.astype(np.float32)
|
||||
if rewards_arr.size == 1:
|
||||
return np.array([1.0], dtype=np.float32)
|
||||
r_min, r_max = rewards_arr.min(), rewards_arr.max()
|
||||
if r_max == r_min:
|
||||
return np.ones_like(rewards_arr, dtype=np.float32)
|
||||
return ((rewards_arr - r_min) / (r_max - r_min)).astype(np.float32)
|
||||
|
||||
|
||||
def compute_instruction_rewards_for_prefixes(
|
||||
model: TOPRewardModel,
|
||||
encoder: TOPRewardEncoderProcessorStep,
|
||||
dataset: LeRobotDataset,
|
||||
ep_start: int,
|
||||
num_frames: int,
|
||||
task: str,
|
||||
image_key: str,
|
||||
num_samples: int | None,
|
||||
device: str,
|
||||
) -> np.ndarray:
|
||||
"""Score an episode via prefix sweep and return a per-frame normalised curve."""
|
||||
if num_samples is None or num_samples >= num_frames:
|
||||
prefix_lengths = np.arange(1, num_frames + 1, dtype=np.int64)
|
||||
else:
|
||||
prefix_lengths = np.unique(np.linspace(1, num_frames, num_samples).round().astype(np.int64))
|
||||
|
||||
episode_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
rewards: list[float] = []
|
||||
for length in prefix_lengths:
|
||||
frames = episode_frames[: int(length)].unsqueeze(0) # (1, T, C, H, W)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {image_key: frames},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
|
||||
}
|
||||
encoded = encoder(transition)
|
||||
obs = encoded[TransitionKey.OBSERVATION]
|
||||
batch = {
|
||||
key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in obs.items()
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
reward = model.compute_reward(batch)
|
||||
rewards.append(float(reward.item()))
|
||||
|
||||
normalized_rewards = normalize_rewards(rewards)
|
||||
|
||||
if prefix_lengths.shape[0] == num_frames:
|
||||
return normalized_rewards
|
||||
|
||||
return np.interp(
|
||||
np.arange(1, num_frames + 1, dtype=np.float64),
|
||||
prefix_lengths.astype(np.float64),
|
||||
normalized_rewards.astype(np.float64),
|
||||
).astype(np.float32)
|
||||
|
||||
|
||||
def compute_topreward_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str | None = None,
|
||||
vlm_name: str | None = None,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
num_samples: int | None = None,
|
||||
fps: float | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
) -> Path:
|
||||
"""Run TOPReward over a dataset and write per-frame progress."""
|
||||
if reward_model_path is not None:
|
||||
logging.info(f"Loading TOPReward config from: {reward_model_path}")
|
||||
model = TOPRewardModel.from_pretrained(reward_model_path)
|
||||
config = model.config
|
||||
config.device = device
|
||||
if vlm_name is not None and vlm_name != config.vlm_name:
|
||||
logging.info(f"Overriding vlm_name from config: {config.vlm_name} -> {vlm_name}")
|
||||
config.vlm_name = vlm_name
|
||||
model = TOPRewardModel(config)
|
||||
else:
|
||||
config_kwargs: dict[str, Any] = {"device": device}
|
||||
if vlm_name is not None:
|
||||
config_kwargs["vlm_name"] = vlm_name
|
||||
if fps is not None:
|
||||
config_kwargs["fps"] = fps
|
||||
config = TOPRewardConfig(**config_kwargs)
|
||||
logging.info(f"Constructing TOPReward with VLM: {config.vlm_name}")
|
||||
model = TOPRewardModel(config)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
encoder = TOPRewardEncoderProcessorStep(
|
||||
vlm_name=config.vlm_name,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=None, # no tail-crop: we control prefix length explicitly
|
||||
fps=config.fps,
|
||||
prompt_prefix=config.prompt_prefix,
|
||||
prompt_suffix_template=config.prompt_suffix_template,
|
||||
add_chat_template=config.add_chat_template,
|
||||
max_length=config.max_input_length,
|
||||
)
|
||||
|
||||
image_key = config.image_key
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
|
||||
logging.info(f"Processing {len(episode_indices)} episode(s)")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(episode_indices, desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
per_frame = compute_instruction_rewards_for_prefixes(
|
||||
model=model,
|
||||
encoder=encoder,
|
||||
dataset=dataset,
|
||||
ep_start=ep_start,
|
||||
num_frames=num_frames,
|
||||
task=task,
|
||||
image_key=image_key,
|
||||
num_samples=num_samples,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
schema_metadata: dict[bytes, bytes] = {b"vlm_name": config.vlm_name.encode()}
|
||||
if reward_model_path is not None:
|
||||
schema_metadata[b"reward_model_path"] = reward_model_path.encode()
|
||||
table = table.replace_schema_metadata(schema_metadata)
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame TOPReward progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Sparse-dense mode (matches upstream TOPReward num_samples=15)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-samples 15
|
||||
|
||||
# Use a smaller VLM
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path", type=str, default=None, help="Optional TOPReward LeRobot config."
|
||||
)
|
||||
parser.add_argument("--vlm-name", type=str, default=None, help="Override the VLM backbone (HF Hub id).")
|
||||
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Anchor prefix samples per episode. None = dense. 15 matches upstream.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Process only these episode indices (e.g. --episodes 0 or --episodes 0 5 10).",
|
||||
)
|
||||
parser.add_argument("--fps", type=float, default=None, help="Override TOPRewardConfig.fps.")
|
||||
parser.add_argument(
|
||||
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
output_path = compute_topreward_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=args.reward_model_path,
|
||||
vlm_name=args.vlm_name,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
num_samples=args.num_samples,
|
||||
fps=args.fps,
|
||||
episodes=args.episodes,
|
||||
)
|
||||
|
||||
print(f"\nTOPReward progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
146
src/lerobot/rewards/topreward/configuration_topreward.py
Normal file
146
src/lerobot/rewards/topreward/configuration_topreward.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Default prompt scaffolding from the upstream TOPReward paper / reference
|
||||
# implementation (``QwenClient.compute_instruction_reward``). The prompt
|
||||
# scores the terminal ``True`` token in ``f"{instruction} ... True"``
|
||||
# given the video.
|
||||
DEFAULT_PROMPT_PREFIX = (
|
||||
"The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
)
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE = (
|
||||
"{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("topreward")
|
||||
@dataclass
|
||||
class TOPRewardConfig(RewardModelConfig):
|
||||
"""Configuration for the TOPReward zero-shot reward model.
|
||||
|
||||
TOPReward is **zero-shot**: it has no learnable parameters of its own.
|
||||
The "model" is a generic vision-language model (default
|
||||
``Qwen/Qwen3-VL-8B-Instruct``) used with a fixed prompt to extract
|
||||
token log-probabilities as a reward signal. There is therefore no
|
||||
fine-tuned checkpoint to host: ``pretrained_path`` is unused at
|
||||
runtime — the model identity is :attr:`vlm_name` (an HF Hub id).
|
||||
|
||||
Args:
|
||||
vlm_name: Hugging Face Hub id of the underlying VLM. Must be a
|
||||
Qwen3-VL family model (the only client implemented in this
|
||||
LeRobot port).
|
||||
torch_dtype: Torch dtype name passed to the VLM loader
|
||||
(``"auto"``, ``"bfloat16"``, ``"float16"``, ...).
|
||||
attn_implementation: ``transformers`` attention implementation
|
||||
(e.g. ``"flash_attention_2"``, ``"sdpa"``). Defaults to
|
||||
``None`` so the upstream picks the best available.
|
||||
image_key: Observation key that holds the trajectory frames.
|
||||
task_key: Complementary-data key that holds the task instruction.
|
||||
default_task: Fallback instruction when ``task_key`` is absent.
|
||||
max_frames: Cap on the number of frames fed to the VLM per
|
||||
sample. ``None`` = use all frames.
|
||||
fps: Frames-per-second metadata for the Qwen video processor.
|
||||
prompt_prefix: Text shown to the VLM right after the video and
|
||||
before the suffix template.
|
||||
prompt_suffix_template: Suffix appended after ``prompt_prefix``.
|
||||
Must contain ``{instruction}``; the VLM scores the
|
||||
log-likelihood of the tokens that follow the prefix.
|
||||
add_chat_template: If ``True``, wrap the full prompt with the
|
||||
tokenizer's chat template before tokenisation (matches
|
||||
upstream ``add_chat_template=True``).
|
||||
success_threshold: Optional log-prob threshold. If finite,
|
||||
:meth:`TOPRewardModel.compute_reward` returns
|
||||
``(reward > success_threshold).float()`` instead of the raw
|
||||
log-prob.
|
||||
max_input_length: Hard limit on the total tokenized input length;
|
||||
samples that exceed it raise a ``ValueError``.
|
||||
"""
|
||||
|
||||
# Path to a local LeRobot dir or HF repo that holds a ``config.json``
|
||||
# snapshot of this TOPRewardConfig. The VLM weights themselves are
|
||||
# always identified by ``vlm_name``.
|
||||
pretrained_path: str | None = None
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
torch_dtype: str = "auto"
|
||||
attn_implementation: str | None = None
|
||||
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
|
||||
success_threshold: float = float("-inf")
|
||||
max_input_length: int = 32768
|
||||
|
||||
license: str | None = "mit" # matches upstream TOPReward
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be > 0, got {self.fps}")
|
||||
if "{instruction}" not in self.prompt_suffix_template:
|
||||
raise ValueError(
|
||||
"prompt_suffix_template must contain `{instruction}` so the model "
|
||||
"scores the log-likelihood of the task suffix."
|
||||
)
|
||||
if self.max_input_length <= 0:
|
||||
raise ValueError(f"max_input_length must be > 0, got {self.max_input_length}")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("reward", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"TOPReward requires image input feature {self.image_key!r}")
|
||||
238
src/lerobot/rewards/topreward/modeling_topreward.py
Normal file
238
src/lerobot/rewards/topreward/modeling_topreward.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Copyright 2026 Shirui Chen, Cole Harrison, Ying-Chun Lee, Angela Jin Yang,
|
||||
# Zhongzheng Ren, Lillian J. Ratliff, Jiafei Duan, Dieter Fox, Ranjay Krishna
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics.
|
||||
|
||||
Paper: https://arxiv.org/abs/2602.19313
|
||||
Project: https://topreward.github.io/webpage/
|
||||
Original code: https://github.com/TOPReward/TOPReward
|
||||
Backbone: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct (default)
|
||||
|
||||
TOPReward is a **zero-shot** reward model: it has no fine-tuned weights of
|
||||
its own. Given a video trajectory and a task instruction, it asks an
|
||||
off-the-shelf VLM how likely the instruction is, conditioned on the video,
|
||||
and returns that log-likelihood as the reward signal.
|
||||
|
||||
Inference recipe:
|
||||
|
||||
1. The processor builds a chat-style prompt, tokenises it, and emits
|
||||
``input_ids``, ``attention_mask``, vision tensors, and ``labels``.
|
||||
The processor label-masks everything except the terminal answer token with
|
||||
``-100``.
|
||||
2. Forward the full token sequence through the VLM.
|
||||
3. Read the terminal answer token log-probability from the logits as the
|
||||
scalar reward.
|
||||
|
||||
With the default ``prompt_suffix_template``, the only unmasked token is the
|
||||
literal ``"True"`` at the end — the reward is
|
||||
``log P("True" | video + prompt + instruction)``.
|
||||
|
||||
This LeRobot port is **inference-only and not trainable** — :meth:`forward`
|
||||
is intentionally inherited from :class:`PreTrainedRewardModel` and raises
|
||||
``NotImplementedError``, making :attr:`PreTrainedRewardModel.is_trainable`
|
||||
return ``False``.
|
||||
|
||||
Because the VLM weights live on the Hugging Face Hub under their canonical
|
||||
id (``Qwen/Qwen3-VL-8B-Instruct`` etc.) and TOPReward never modifies them,
|
||||
:meth:`_save_pretrained` and :meth:`from_pretrained` are overridden so a
|
||||
TOPReward LeRobot "checkpoint" is a single ``config.json`` (the VLM is
|
||||
re-fetched from the Hub at load time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
from torch.nn.functional import cross_entropy
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX, TOPREWARD_INPUT_KEYS
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
Qwen3VLForConditionalGeneration = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="TOPRewardModel")
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype | str:
|
||||
"""Resolve a torch dtype name; ``"auto"`` is passed through verbatim."""
|
||||
if name == "auto":
|
||||
return "auto"
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
class TOPRewardModel(PreTrainedRewardModel):
|
||||
"""TOPReward zero-shot reward model."""
|
||||
|
||||
name = "topreward"
|
||||
config_class = TOPRewardConfig
|
||||
|
||||
def __init__(self, config: TOPRewardConfig) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
model_kwargs: dict[str, Any] = {"dtype": torch_dtype, "trust_remote_code": True}
|
||||
if config.attn_implementation is not None:
|
||||
model_kwargs["attn_implementation"] = config.attn_implementation
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(config.vlm_name, **model_kwargs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Return one log-prob reward per sample in the batch."""
|
||||
inputs: dict[str, Any] = {}
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
batch_key = f"{TOPREWARD_FEATURE_PREFIX}{key}"
|
||||
if batch_key not in batch:
|
||||
raise KeyError(
|
||||
f"TOPReward batch missing `{batch_key}`. Make sure the "
|
||||
"TOPRewardEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
inputs[key] = batch[batch_key]
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
|
||||
labels = inputs.pop("labels")
|
||||
inputs["logits_to_keep"] = 2
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
rewards = -cross_entropy(logits[:, -2, :].float(), labels[:, -1], reduction="none")
|
||||
if np.isfinite(self.config.success_threshold):
|
||||
rewards = (rewards > self.config.success_threshold).float()
|
||||
return rewards.to(self.config.device or "cpu")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Save ``config.json`` only."""
|
||||
self.config._save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False, # noqa: ARG003 — accepted for API parity; unused (no safetensors to load)
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Load a TOPReward configuration and instantiate the wrapped VLM."""
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(config, TOPRewardConfig):
|
||||
raise TypeError(
|
||||
f"Expected a TOPRewardConfig, got {type(config).__name__}. Make sure "
|
||||
f"`pretrained_name_or_path={pretrained_name_or_path!r}` points at a "
|
||||
"TOPReward checkpoint."
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
if not os.path.isdir(model_id):
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
instance = cls(config, **kwargs)
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
def push_model_to_hub(self, cfg: TrainPipelineConfig):
|
||||
"""Push the TOPReward ``config.json`` + model card to the Hub."""
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
saved_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.config._save_pretrained(saved_path)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path)
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload TOPReward config and readme",
|
||||
allow_patterns=["*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log", "*.safetensors"],
|
||||
)
|
||||
|
||||
logger.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
305
src/lerobot/rewards/topreward/processor_topreward.py
Normal file
305
src/lerobot/rewards/topreward/processor_topreward.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward pre/post processing pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.topreward.configuration_topreward import (
|
||||
DEFAULT_PROMPT_PREFIX,
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE,
|
||||
TOPRewardConfig,
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor
|
||||
else:
|
||||
AutoProcessor = None
|
||||
|
||||
TOPREWARD_FEATURE_PREFIX = f"{OBS_PREFIX}topreward."
|
||||
|
||||
_TRUE_ANSWER = "True"
|
||||
|
||||
TOPREWARD_VLM_INPUT_KEYS = (
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
TOPREWARD_INPUT_KEYS = TOPREWARD_VLM_INPUT_KEYS + ("labels",)
|
||||
|
||||
|
||||
def _prepare_video_batch(video: Tensor, *, max_frames: int | None) -> Tensor:
|
||||
"""Return videos as ``(B, T, C, H, W)`` uint8 tensors for Qwen3-VL."""
|
||||
if video.ndim == 4:
|
||||
video = video.unsqueeze(1)
|
||||
elif video.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected TOPReward frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(video.shape)}"
|
||||
)
|
||||
|
||||
if max_frames is not None:
|
||||
video = video[:, -max_frames:]
|
||||
if video.shape[-1] in (1, 3):
|
||||
video = video.permute(0, 1, 4, 2, 3)
|
||||
elif video.shape[2] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
if video.is_floating_point():
|
||||
video = video * 255.0
|
||||
|
||||
return video.clamp(0, 255).to(torch.uint8).contiguous()
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("TOPReward expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"TOPReward task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="topreward_encoder")
|
||||
class TOPRewardEncoderProcessorStep(ProcessorStep):
|
||||
"""Encode raw frames + task into Qwen-VL tensors for the TOPReward model.
|
||||
|
||||
Loads a :class:`~transformers.AutoProcessor` matching ``vlm_name`` and
|
||||
builds the full chat prompt including the instruction suffix. The
|
||||
resulting ``input_ids``, ``attention_mask``, vision tensors, and
|
||||
``labels`` are written under the ``observation.topreward.*`` namespace
|
||||
so the model can score without re-tokenising.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes ``observation[f"{TOPREWARD_FEATURE_PREFIX}<name>"]`` for the
|
||||
Qwen-VL tensors plus ``labels``.
|
||||
"""
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
max_length: int = 32768
|
||||
|
||||
_processor: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
self._processor = AutoProcessor.from_pretrained(self.vlm_name, trust_remote_code=True)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"TOPReward expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
videos = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
videos = _prepare_video_batch(videos, max_frames=self.max_frames)
|
||||
|
||||
batch_size = videos.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
encoded = self._encode_batch(videos, tasks, batch_size)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for key, value in encoded.items():
|
||||
new_observation[f"{TOPREWARD_FEATURE_PREFIX}{key}"] = value
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def _encode_batch(self, videos: Tensor, tasks: list[str], batch_size) -> dict[str, Any]:
|
||||
"""Tokenise a batch of (frames, task) pairs into Qwen-VL tensors.
|
||||
|
||||
The loop only builds per-sample chat strings. Tokenisation, padding,
|
||||
video preprocessing, and label construction are batched.
|
||||
"""
|
||||
|
||||
texts: list[str] = []
|
||||
video_metadata = [
|
||||
{
|
||||
"total_num_frames": int(videos.shape[1]),
|
||||
"fps": float(self.fps),
|
||||
"frames_indices": list(range(int(videos.shape[1]))),
|
||||
}
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
eos_token = self._processor.tokenizer.eos_token
|
||||
|
||||
for i in range(batch_size):
|
||||
instruction_suffix = self.prompt_suffix_template.format(instruction=tasks[i])
|
||||
if self.add_chat_template:
|
||||
suffix_for_template = instruction_suffix.removesuffix(_TRUE_ANSWER).rstrip()
|
||||
templated_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": videos[i], "fps": self.fps},
|
||||
{"type": "text", "text": f"{self.prompt_prefix}{suffix_for_template}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self._processor.apply_chat_template(
|
||||
templated_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
full_text = f"{prompt_chat}{_TRUE_ANSWER}"
|
||||
else:
|
||||
user_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": videos[i], "fps": self.fps},
|
||||
{"type": "text", "text": self.prompt_prefix},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self._processor.apply_chat_template(
|
||||
user_messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
if eos_token is not None:
|
||||
prompt_chat = prompt_chat.split(eos_token)[0]
|
||||
full_text = f"{prompt_chat}{instruction_suffix}"
|
||||
|
||||
texts.append(full_text)
|
||||
|
||||
result = self._processor(
|
||||
text=texts,
|
||||
videos=videos,
|
||||
video_metadata=video_metadata,
|
||||
do_sample_frames=False,
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = result["input_ids"]
|
||||
|
||||
if input_ids.shape[-1] > self.max_length:
|
||||
raise ValueError(
|
||||
f"TOPReward input length {input_ids.shape[-1]} exceeds max_length "
|
||||
f"{self.max_length}; lower `max_frames` or raise `max_length`."
|
||||
)
|
||||
|
||||
labels = torch.full_like(input_ids, -100)
|
||||
labels[:, -1] = input_ids[:, -1]
|
||||
result["labels"] = labels
|
||||
return result
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"vlm_name": self.vlm_name,
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
"fps": self.fps,
|
||||
"prompt_prefix": self.prompt_prefix,
|
||||
"prompt_suffix_template": self.prompt_suffix_template,
|
||||
"add_chat_template": self.add_chat_template,
|
||||
"max_length": self.max_length,
|
||||
}
|
||||
|
||||
|
||||
def make_topreward_pre_post_processors(
|
||||
config: TOPRewardConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs TOPReward's
|
||||
encoder (which tokenises the full prompt and emits ``labels``), and
|
||||
moves everything to the configured device. The postprocessor is
|
||||
the identity since TOPReward outputs a single reward tensor.
|
||||
"""
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TOPRewardEncoderProcessorStep(
|
||||
vlm_name=config.vlm_name,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
fps=config.fps,
|
||||
prompt_prefix=config.prompt_prefix,
|
||||
prompt_suffix_template=config.prompt_suffix_template,
|
||||
add_chat_template=config.add_chat_template,
|
||||
max_length=config.max_input_length,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -13,6 +13,8 @@
|
||||
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
|
||||
{% elif model_name == "sarm" %}
|
||||
A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
|
||||
{% elif model_name == "topreward" %}
|
||||
TOPReward is a **zero-shot** reward model that extracts token log-probabilities from an off-the-shelf vision-language model (default Qwen3-VL) as a reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood of the instruction being true, with no fine-tuning required.
|
||||
{% else %}
|
||||
_Reward model type not recognized — please update this template._
|
||||
{% endif %}
|
||||
|
||||
1397
tests/policies/molmoact2/test_molmoact2.py
Normal file
1397
tests/policies/molmoact2/test_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
273
tests/policies/vla_jepa/conftest.py
Normal file
273
tests/policies/vla_jepa/conftest.py
Normal file
@@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BATCH_SIZE = 2
|
||||
ACTION_DIM = 3
|
||||
STATE_DIM = 4
|
||||
IMAGE_SIZE = 8
|
||||
ACTION_HORIZON = 4
|
||||
N_ACTION_STEPS = 2
|
||||
NUM_VIDEO_FRAMES = 3
|
||||
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
|
||||
|
||||
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def set_seed_all(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def make_config(
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> VLAJEPAConfig:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
},
|
||||
output_features={
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
},
|
||||
device="cpu",
|
||||
chunk_size=action_horizon,
|
||||
n_action_steps=min(N_ACTION_STEPS, action_horizon),
|
||||
action_dim=action_dim,
|
||||
state_dim=state_dim,
|
||||
num_video_frames=num_video_frames,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=1,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def make_train_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, 1, state_dim),
|
||||
ACTION: torch.randn(batch_size, action_horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def make_inference_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
state_dim: int = STATE_DIM,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeLanguageLayer(nn.Module):
|
||||
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
|
||||
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
|
||||
return (hidden,)
|
||||
|
||||
|
||||
class _FakeLanguageModel(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
|
||||
self.layers[-1](hidden)
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
class _FakeQwenInnerModel(nn.Module):
|
||||
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.language_model = _FakeLanguageModel(hidden_size)
|
||||
|
||||
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
|
||||
class _FakeQwenBackbone(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
self.config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
text_config=SimpleNamespace(hidden_size=hidden_size),
|
||||
)
|
||||
self.model = _FakeQwenInnerModel(hidden_size)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = self.config.hidden_size
|
||||
values = torch.arange(
|
||||
batch_size * seq_len * hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=torch.float32,
|
||||
).view(batch_size, seq_len, hidden_size)
|
||||
hidden = values / values.numel() + self.weight
|
||||
self.model(input_ids) # call through so the forward hook on layers[-1] fires
|
||||
return SimpleNamespace(hidden_states=[hidden])
|
||||
|
||||
|
||||
class _FakeQwenInterface(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||
return action_tokens, action_token_ids, 2000
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, Tensor]:
|
||||
batch_size = len(images)
|
||||
del images, instructions, action_prompt, embodied_prompt
|
||||
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
token_ids = (
|
||||
[10]
|
||||
+ list(range(1000, 1000 + action_count))
|
||||
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
||||
+ [11]
|
||||
)
|
||||
return {
|
||||
"input_ids": torch.tensor(
|
||||
[token_ids] * batch_size,
|
||||
device=self.model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
|
||||
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
|
||||
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
||||
batch_size, num_frames = pixel_values_videos.shape[:2]
|
||||
hidden_size = self.config.hidden_size
|
||||
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
||||
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
if isinstance(videos, list):
|
||||
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||
else:
|
||||
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||
return {"pixel_values_videos": pixel_values}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
||||
|
||||
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoModel,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoEncoder(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoVideoProcessor,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoProcessor(),
|
||||
)
|
||||
157
tests/policies/vla_jepa/test_action_head.py
Normal file
157
tests/policies/vla_jepa/test_action_head.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
from conftest import (
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
set_seed_all,
|
||||
) # noqa: E402
|
||||
|
||||
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
|
||||
VLAJEPAActionHead,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLAJEPAActionHead
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4), # default test dims
|
||||
(7, 0, 16), # no proprioceptive state, production-like action space
|
||||
(6, 8, 8), # medium dims
|
||||
],
|
||||
)
|
||||
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
|
||||
assert t.shape == (200,)
|
||||
assert torch.isfinite(t).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
timesteps = torch.randint(0, 100, (2,))
|
||||
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
out_with = head._build_inputs(conditioning, actions, state, timesteps)
|
||||
out_none = head._build_inputs(conditioning, actions, None, timesteps)
|
||||
|
||||
assert out_with.ndim == 3 and out_none.ndim == 3
|
||||
if state_dim > 0:
|
||||
assert out_with.shape[1] > out_none.shape[1]
|
||||
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
def test_action_head_forward_gradient_flows() -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
pred = head.predict_action(conditioning, state)
|
||||
assert tuple(pred.shape) == (2, action_horizon, action_dim)
|
||||
assert torch.isfinite(pred).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# action_is_pad masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_action_head_loss_fully_padded_is_zero() -> None:
|
||||
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss = head.forward(conditioning, actions, state, action_is_pad)
|
||||
assert loss.item() == 0.0
|
||||
|
||||
|
||||
def test_action_head_loss_none_matches_no_padding() -> None:
|
||||
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
set_seed_all(0)
|
||||
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
|
||||
|
||||
set_seed_all(0)
|
||||
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
|
||||
|
||||
assert torch.isclose(loss_none, loss_zeros)
|
||||
57
tests/policies/vla_jepa/test_configuration.py
Normal file
57
tests/policies/vla_jepa/test_configuration.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def test_delta_indices() -> None:
|
||||
config = make_config()
|
||||
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
|
||||
assert config.action_delta_indices == list(range(ACTION_HORIZON))
|
||||
|
||||
|
||||
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
|
||||
with pytest.raises(ValueError, match="n_action_steps"):
|
||||
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
|
||||
|
||||
|
||||
def test_too_few_video_frames_raises() -> None:
|
||||
with pytest.raises(ValueError, match="video_horizon"):
|
||||
VLAJEPAConfig(
|
||||
chunk_size=16,
|
||||
n_action_steps=16,
|
||||
num_video_frames=2,
|
||||
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
|
||||
)
|
||||
|
||||
|
||||
def test_validate_features_no_image_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_no_action_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
},
|
||||
output_features={},
|
||||
)
|
||||
with pytest.raises(ValueError, match="action output feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_sets_action_dim_from_feature() -> None:
|
||||
config = make_config(action_dim=6, state_dim=10)
|
||||
assert config.action_dim == 6
|
||||
assert config.state_dim == 10
|
||||
598
tests/policies/vla_jepa/test_vla_jepa.py
Normal file
598
tests/policies/vla_jepa/test_vla_jepa.py
Normal file
@@ -0,0 +1,598 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||
)
|
||||
|
||||
from conftest import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
EXPECTED_ACTION_CHUNK_SHAPE,
|
||||
EXPECTED_SELECT_ACTION_SHAPE,
|
||||
IMAGE_SIZE,
|
||||
N_ACTION_STEPS,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
make_inference_batch,
|
||||
make_train_batch,
|
||||
set_seed_all,
|
||||
)
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
|
||||
from lerobot.utils.constants import ACTION # noqa: E402
|
||||
|
||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||
|
||||
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core training / inference tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
|
||||
batch = make_train_batch()
|
||||
batch_before = deepcopy(batch)
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
assert logs["action_loss"] > 0
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
|
||||
# Batch must not be mutated.
|
||||
assert set(batch) == set(batch_before)
|
||||
for key, value in batch.items():
|
||||
if isinstance(value, Tensor):
|
||||
assert torch.equal(value, batch_before[key])
|
||||
else:
|
||||
assert value == batch_before[key]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_training_forward_various_dims(
|
||||
patch_vla_jepa_external_models: None,
|
||||
action_dim: int,
|
||||
state_dim: int,
|
||||
action_horizon: int,
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.train()
|
||||
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
loss, _ = policy.forward(batch)
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert chunk.device.type == "cpu"
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
a1 = policy.select_action(batch)
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
|
||||
def test_action_generation_various_dims(
|
||||
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.eval()
|
||||
batch = make_inference_batch(state_dim=state_dim)
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert chunk.shape[-1] == action_dim
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
set_seed_all(123)
|
||||
actions_1 = policy.predict_action_chunk(batch)
|
||||
set_seed_all(123)
|
||||
actions_2 = policy.predict_action_chunk(batch)
|
||||
|
||||
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
for seed in [0, 42, 123]:
|
||||
set_seed_all(seed)
|
||||
chunk = policy.predict_action_chunk(make_inference_batch())
|
||||
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action queue behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
# First call fills the queue (n_action_steps items) and pops one.
|
||||
a1 = policy.select_action(batch)
|
||||
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
|
||||
|
||||
# Second call pops from the existing queue without calling predict_action_chunk.
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
policy.select_action(make_inference_batch())
|
||||
assert len(policy._queues[ACTION]) > 0
|
||||
|
||||
policy.reset()
|
||||
assert len(policy._queues[ACTION]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||
from PIL import Image
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
examples = policy._prepare_model_inputs(make_train_batch())
|
||||
|
||||
assert len(examples) == BATCH_SIZE
|
||||
for ex in examples:
|
||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||
assert ex["state"].shape == (1, STATE_DIM)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||
assert "action" not in ex
|
||||
assert "image" in ex and "video" in ex and "lang" in ex
|
||||
|
||||
|
||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch["task"]
|
||||
examples = policy._prepare_model_inputs(batch)
|
||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
batch["task"] = "open the drawer"
|
||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch[OBS_STATE]
|
||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pretrained checkpoint
|
||||
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||
return batch
|
||||
|
||||
|
||||
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||
return batch
|
||||
|
||||
|
||||
_CP_ROOT = "lerobot"
|
||||
|
||||
# Each tuple: (repo_id, enable_world_model)
|
||||
_HUB_VARIANTS = [
|
||||
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||
]
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
assert policy.config.enable_world_model == enable_world_model
|
||||
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
policy.train()
|
||||
|
||||
batch = _make_hub_train_batch(policy)
|
||||
loss, logs = policy.forward(batch)
|
||||
assert torch.isfinite(loss)
|
||||
assert "action_loss" in logs
|
||||
if enable_world_model:
|
||||
assert "wm_loss" in logs
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
qwen_params = list(policy.model.qwen.parameters())
|
||||
assert all(not p.requires_grad for p in qwen_params)
|
||||
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
assert policy.model.video_encoder is None
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_libero_inference_shape() -> None:
|
||||
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||
policy.eval()
|
||||
batch = _make_hub_inference_batch(policy)
|
||||
action = policy.select_action(batch)
|
||||
assert action.shape[-1] == policy.config.action_dim
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Postprocessor unnormalization tests
|
||||
#
|
||||
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
|
||||
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
|
||||
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
return {
|
||||
ACTION: {
|
||||
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
|
||||
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
|
||||
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
|
||||
rng = np.random.default_rng(7)
|
||||
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
actions_tensor = torch.from_numpy(actions_np)
|
||||
transition = policy_action_to_transition(actions_tensor)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
|
||||
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
|
||||
# Deliberately out-of-range inputs
|
||||
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
|
||||
clipped = np.clip(actions_np, -1.0, 1.0)
|
||||
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
clip_step = ClipActionsProcessorStep()
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
transition = policy_action_to_transition(torch.from_numpy(actions_np))
|
||||
transition = clip_step(transition)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_applied_after_predict_action_chunk(
|
||||
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
|
||||
|
||||
Verifies the split: predict_action_chunk returns normalized actions, and calling the
|
||||
postprocessor on them produces the correctly unnormalized result.
|
||||
"""
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||
|
||||
cfg = make_config()
|
||||
cfg.clip_normalized_actions = False
|
||||
cfg.binarize_gripper_action = False
|
||||
policy = VLAJEPAPolicy(cfg)
|
||||
policy.eval()
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||
|
||||
batch = make_inference_batch()
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
|
||||
# predict_action_chunk returns raw (normalized) actions
|
||||
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
|
||||
"predict_action_chunk should return raw actions without unnormalization applied."
|
||||
)
|
||||
|
||||
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
|
||||
unnormed = postprocessor(chunk)
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
|
||||
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World-model view adjustment (padding / trimming) tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_MULTIVIEW_NUM_FRAMES = 4 # must be >= 2 * jepa_tubelet_size (=2) for world-model tests
|
||||
|
||||
|
||||
def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig:
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
**{
|
||||
f"{OBS_IMAGES}.cam{i}": PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)
|
||||
)
|
||||
for i in range(num_views)
|
||||
},
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
||||
},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
device="cpu",
|
||||
chunk_size=ACTION_HORIZON,
|
||||
n_action_steps=N_ACTION_STEPS,
|
||||
action_dim=ACTION_DIM,
|
||||
state_dim=STATE_DIM,
|
||||
num_video_frames=_MULTIVIEW_NUM_FRAMES,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=jepa_tubelet_size,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict:
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
batch = {
|
||||
f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)
|
||||
for i in range(num_views)
|
||||
}
|
||||
batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM)
|
||||
batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM)
|
||||
batch["task"] = ["pick up the cube"] * batch_size
|
||||
return batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_views",
|
||||
[
|
||||
1, # fewer views than jepa_tubelet_size → first view duplicated
|
||||
2, # exact match → unchanged
|
||||
3, # more views than jepa_tubelet_size → trimmed to first two
|
||||
],
|
||||
)
|
||||
def test_training_forward_world_model_view_adjustment(
|
||||
patch_vla_jepa_external_models: None,
|
||||
num_views: int,
|
||||
) -> None:
|
||||
"""World-model view padding/trimming must not break the training forward pass."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
loss, logs = policy.forward(_make_multiview_train_batch(num_views=num_views))
|
||||
assert torch.isfinite(loss)
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
|
||||
def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None:
|
||||
"""With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
|
||||
captured_videos: list = []
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=1))
|
||||
|
||||
# reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …)
|
||||
assert len(captured_videos) == BATCH_SIZE * 2
|
||||
for i in range(BATCH_SIZE):
|
||||
np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1])
|
||||
|
||||
|
||||
def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None:
|
||||
"""With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
|
||||
captured_videos: list = []
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=3))
|
||||
|
||||
# Only B*2 items must reach the encoder, not B*3.
|
||||
assert len(captured_videos) == BATCH_SIZE * 2
|
||||
60
tests/policies/vla_jepa/test_world_model.py
Normal file
60
tests/policies/vla_jepa/test_world_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.world_model import (
|
||||
ActionConditionedVideoPredictor,
|
||||
)
|
||||
|
||||
_ACTION_EMBED_DIM = 8
|
||||
|
||||
|
||||
def _make_predictor(
|
||||
embed_dim: int = 8,
|
||||
action_embed_dim: int = _ACTION_EMBED_DIM,
|
||||
predictor_embed_dim: int = 24,
|
||||
num_action_tokens: int = 2,
|
||||
tokens_per_frame: int = 1,
|
||||
) -> ActionConditionedVideoPredictor:
|
||||
return ActionConditionedVideoPredictor(
|
||||
num_frames=1,
|
||||
img_size=(1, tokens_per_frame),
|
||||
patch_size=1,
|
||||
tubelet_size=1,
|
||||
embed_dim=embed_dim,
|
||||
action_embed_dim=action_embed_dim,
|
||||
predictor_embed_dim=predictor_embed_dim,
|
||||
depth=1,
|
||||
num_heads=2,
|
||||
mlp_ratio=2.0,
|
||||
num_action_tokens_per_step=num_action_tokens,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch,num_steps,tokens_per_frame,embed_dim",
|
||||
[
|
||||
(1, 2, 1, 8),
|
||||
(2, 3, 4, 8),
|
||||
(4, 5, 2, 16),
|
||||
],
|
||||
)
|
||||
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
|
||||
predictor = _make_predictor(
|
||||
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
|
||||
)
|
||||
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
|
||||
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
|
||||
out = predictor(frame_tokens, action_tokens)
|
||||
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
def test_predictor_step_mismatch_raises() -> None:
|
||||
predictor = _make_predictor(tokens_per_frame=4)
|
||||
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
|
||||
with pytest.raises(RuntimeError):
|
||||
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch
|
||||
296
tests/rewards/test_modeling_topreward.py
Normal file
296
tests/rewards/test_modeling_topreward.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the TOPReward reward model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.factory import get_reward_model_class, make_reward_model_config
|
||||
from lerobot.rewards.topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX, TOPREWARD_INPUT_KEYS
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
class _FakeQwenModel(torch.nn.Module):
|
||||
"""Stand-in for ``Qwen3VLForConditionalGeneration``.
|
||||
|
||||
Returns a ``SimpleNamespace`` with ``logits`` of a controlled shape so
|
||||
the log-prob extraction path in ``compute_reward`` can be exercised
|
||||
without downloading real VLM weights.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._param = torch.nn.Parameter(torch.zeros(1))
|
||||
self._reward_value: float = -1.5
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
def forward( # noqa: ARG002
|
||||
self, input_ids, attention_mask=None, labels=None, logits_to_keep=0, **kwargs
|
||||
):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
vocab_size = 1000
|
||||
logits = torch.zeros(batch_size, seq_len, vocab_size)
|
||||
# Place a controlled log-prob at the target token position so the
|
||||
# model returns a predictable reward value.
|
||||
# The label-masked suffix is the last token.
|
||||
# After the causal-LM shift (logits[:, :-1], labels[:, 1:]) the scored
|
||||
# position is logits[:, -2, :] predicting labels[:, -1].
|
||||
# We set logits so that log_softmax at the target token ≈ _reward_value.
|
||||
for i in range(batch_size):
|
||||
target_idx = int(input_ids[i, -1].item())
|
||||
logits[i, -2, target_idx] = self._reward_value * -10 # high logit -> high log-prob
|
||||
if logits_to_keep:
|
||||
logits = logits[:, -logits_to_keep:, :]
|
||||
return SimpleNamespace(logits=logits)
|
||||
|
||||
|
||||
def _patch_build(monkeypatch) -> None:
|
||||
"""Stub out HF AutoX so TOPReward construction is cheap and offline."""
|
||||
from lerobot.rewards.topreward import modeling_topreward
|
||||
|
||||
monkeypatch.setattr(modeling_topreward, "Qwen3VLForConditionalGeneration", _FakeQwenModel)
|
||||
|
||||
|
||||
def _make_batch(
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
labels: torch.Tensor | None = None,
|
||||
*,
|
||||
omit: str | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Build a ``compute_reward``-ready batch using TOPReward's namespaced keys."""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
|
||||
batch: dict[str, torch.Tensor] = {}
|
||||
if labels is not None:
|
||||
batch[f"{TOPREWARD_FEATURE_PREFIX}labels"] = labels
|
||||
batch.update(
|
||||
{
|
||||
f"{TOPREWARD_FEATURE_PREFIX}input_ids": input_ids,
|
||||
f"{TOPREWARD_FEATURE_PREFIX}attention_mask": attention_mask,
|
||||
f"{TOPREWARD_FEATURE_PREFIX}pixel_values_videos": torch.zeros(
|
||||
batch_size, 1536, dtype=torch.float32
|
||||
),
|
||||
f"{TOPREWARD_FEATURE_PREFIX}video_grid_thw": torch.ones(batch_size, 3, dtype=torch.long),
|
||||
f"{TOPREWARD_FEATURE_PREFIX}mm_token_type_ids": torch.zeros_like(input_ids),
|
||||
}
|
||||
)
|
||||
if omit is not None:
|
||||
batch.pop(f"{TOPREWARD_FEATURE_PREFIX}{omit}", None)
|
||||
return batch
|
||||
|
||||
|
||||
def _terminal_labels(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
labels = torch.full_like(input_ids, -100)
|
||||
labels[:, -1] = input_ids[:, -1]
|
||||
return labels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry + factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_topreward_config_registered():
|
||||
assert "topreward" in RewardModelConfig.get_known_choices()
|
||||
assert RewardModelConfig.get_choice_class("topreward") is TOPRewardConfig
|
||||
assert isinstance(make_reward_model_config("topreward", device="cpu"), TOPRewardConfig)
|
||||
|
||||
|
||||
def test_topreward_factory_returns_in_tree_class():
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
assert get_reward_model_class("topreward") is TOPRewardModel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_topreward_config_rejects_zero_max_frames():
|
||||
with pytest.raises(ValueError, match="max_frames must be >= 1"):
|
||||
TOPRewardConfig(device="cpu", max_frames=0)
|
||||
|
||||
|
||||
def test_topreward_config_rejects_non_positive_fps():
|
||||
with pytest.raises(ValueError, match="fps must be > 0"):
|
||||
TOPRewardConfig(device="cpu", fps=0.0)
|
||||
|
||||
|
||||
def test_topreward_config_rejects_suffix_without_instruction_placeholder():
|
||||
with pytest.raises(ValueError, match=r"\{instruction\}"):
|
||||
TOPRewardConfig(device="cpu", prompt_suffix_template="no placeholder here")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_reward
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_returns_one_scalar_per_sample(monkeypatch):
|
||||
"""``compute_reward`` must return a ``(B,)`` float32 tensor with one
|
||||
log-prob reward per sample, consuming pre-encoded Qwen-VL tensors."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (2, 10))
|
||||
attention_mask = torch.ones(2, 10, dtype=torch.long)
|
||||
labels = _terminal_labels(input_ids)
|
||||
|
||||
batch = _make_batch(input_ids, attention_mask, labels)
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert rewards.shape == (2,)
|
||||
assert rewards.dtype == torch.float32
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_applies_success_threshold(monkeypatch):
|
||||
"""When ``success_threshold`` is finite, the model returns binary success."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu", success_threshold=0.0)
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (2, 10))
|
||||
attention_mask = torch.ones(2, 10, dtype=torch.long)
|
||||
labels = _terminal_labels(input_ids)
|
||||
|
||||
batch = _make_batch(input_ids, attention_mask, labels)
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert rewards.shape == (2,)
|
||||
assert set(rewards.tolist()).issubset({0.0, 1.0})
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_errors_when_inputs_missing(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
with pytest.raises(KeyError, match=r"observation\.topreward\.input_ids"):
|
||||
model.compute_reward(_make_batch(torch.randint(0, 100, (1, 10)), omit="input_ids"))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_errors_when_labels_missing(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (1, 10))
|
||||
with pytest.raises(KeyError, match=r"observation\.topreward\.labels"):
|
||||
model.compute_reward(_make_batch(input_ids, labels=None))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_requires_all_encoder_keys(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (1, 10))
|
||||
labels = _terminal_labels(input_ids)
|
||||
required_encoder_keys = set(TOPREWARD_INPUT_KEYS) - {"input_ids", "labels"}
|
||||
|
||||
for key in required_encoder_keys:
|
||||
with pytest.raises(KeyError, match=rf"observation\.topreward\.{key}"):
|
||||
model.compute_reward(_make_batch(input_ids, labels=labels, omit=key))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Save / load — config-only checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_save_pretrained_writes_only_config_json(monkeypatch, tmp_path):
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(
|
||||
device="cpu",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
fps=4.0,
|
||||
image_key="observation.images.front",
|
||||
)
|
||||
model = TOPRewardModel(cfg)
|
||||
model.save_pretrained(str(tmp_path))
|
||||
|
||||
assert (tmp_path / CONFIG_NAME).exists()
|
||||
assert not (tmp_path / SAFETENSORS_SINGLE_FILE).exists()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_from_pretrained_local_dir_roundtrips_config(monkeypatch, tmp_path):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(
|
||||
device="cpu",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
fps=4.0,
|
||||
image_key="observation.images.front",
|
||||
add_chat_template=True,
|
||||
success_threshold=-1.5,
|
||||
)
|
||||
TOPRewardModel(cfg).save_pretrained(str(tmp_path))
|
||||
|
||||
reloaded = TOPRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
assert isinstance(reloaded.config, TOPRewardConfig)
|
||||
assert reloaded.config.vlm_name == "Qwen/Qwen3-VL-8B-Instruct"
|
||||
assert reloaded.config.fps == 4.0
|
||||
assert reloaded.config.image_key == "observation.images.front"
|
||||
assert reloaded.config.add_chat_template is True
|
||||
assert reloaded.config.success_threshold == -1.5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_is_not_trainable(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
assert model.is_trainable is False
|
||||
with pytest.raises(NotImplementedError, match="not trainable"):
|
||||
model.forward({"x": torch.zeros(1)})
|
||||
80
tests/rewards/test_topreward.py
Normal file
80
tests/rewards/test_topreward.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end TOPReward smoke test with the real Qwen3-VL model."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig # noqa: E402
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel # noqa: E402
|
||||
from lerobot.rewards.topreward.processor_topreward import ( # noqa: E402
|
||||
TOPREWARD_FEATURE_PREFIX,
|
||||
TOPREWARD_INPUT_KEYS,
|
||||
make_topreward_pre_post_processors,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires downloading and loading Qwen3-VL and is not meant for CI",
|
||||
)
|
||||
|
||||
|
||||
def _make_dummy_topreward_batch(image_key: str, task_key: str) -> dict[str, object]:
|
||||
num_frames = 4
|
||||
image_size = 64
|
||||
frames = torch.zeros(1, num_frames, 3, image_size, image_size, dtype=torch.uint8)
|
||||
for frame_idx in range(num_frames):
|
||||
frames[0, frame_idx, 0].fill_(min(frame_idx * 48, 255))
|
||||
frames[0, frame_idx, 1].fill_(96)
|
||||
frames[0, frame_idx, 2].fill_(192)
|
||||
|
||||
return {
|
||||
image_key: frames,
|
||||
task_key: ["pick up the red cube"],
|
||||
}
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_topreward_full_qwen3vl_preprocessor_to_compute_reward():
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
max_frames=4,
|
||||
fps=2.0,
|
||||
max_input_length=4096,
|
||||
)
|
||||
|
||||
preprocessor, _ = make_topreward_pre_post_processors(cfg)
|
||||
encoded_batch = preprocessor(_make_dummy_topreward_batch(cfg.image_key, cfg.task_key))
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
assert f"{TOPREWARD_FEATURE_PREFIX}{key}" in encoded_batch
|
||||
|
||||
model = TOPRewardModel(cfg)
|
||||
try:
|
||||
model.to(cfg.device)
|
||||
model.eval()
|
||||
rewards = model.compute_reward(encoded_batch)
|
||||
finally:
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert rewards.shape == (1,)
|
||||
assert rewards.dtype == torch.float32
|
||||
assert torch.isfinite(rewards).all()
|
||||
246
tests/rewards/test_topreward_processor.py
Normal file
246
tests/rewards/test_topreward_processor.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for TOPReward's pre-processing helpers and encoder step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.rewards.topreward.processor_topreward import (
|
||||
TOPREWARD_FEATURE_PREFIX,
|
||||
TOPREWARD_INPUT_KEYS,
|
||||
_expand_tasks,
|
||||
_prepare_video_batch,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _prepare_video_batch — raw image/video batch -> (B, T, C, H, W) uint8
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_prepare_video_batch_batched_chw_float_is_converted_to_uint8():
|
||||
video = torch.rand(2, 4, 3, 8, 8)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert tensor.shape == (2, 4, 3, 8, 8)
|
||||
assert tensor.dtype == torch.uint8
|
||||
assert tensor.min() >= 0 and tensor.max() <= 255
|
||||
|
||||
|
||||
def test_prepare_video_batch_batched_thwc_uint8_is_permuted_to_channel_first():
|
||||
video = torch.randint(0, 256, (2, 3, 8, 8, 3), dtype=torch.uint8)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert tensor.shape == (2, 3, 3, 8, 8)
|
||||
assert tensor.dtype == torch.uint8
|
||||
|
||||
|
||||
def test_prepare_video_batch_max_frames_tail_crops_recent_frames():
|
||||
video = torch.zeros(1, 10, 3, 4, 4)
|
||||
for t in range(10):
|
||||
video[:, t] = t / 9.0
|
||||
|
||||
tensor = _prepare_video_batch(video, max_frames=3)
|
||||
|
||||
assert tensor.shape == (1, 3, 3, 4, 4)
|
||||
assert int(tensor[0, 0, 0, 0, 0]) == int(7 / 9 * 255)
|
||||
assert int(tensor[0, -1, 0, 0, 0]) == 255
|
||||
|
||||
|
||||
def test_prepare_video_batch_rejects_3d_input():
|
||||
with pytest.raises(ValueError, match="Expected TOPReward frames"):
|
||||
_prepare_video_batch(torch.zeros(4, 8, 8), max_frames=None)
|
||||
|
||||
|
||||
def test_prepare_video_batch_floats_above_one_are_rescaled_and_clipped():
|
||||
video = torch.full((1, 1, 3, 2, 2), 5.0)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert tensor.shape == (1, 1, 3, 2, 2)
|
||||
assert int(tensor.max()) == 255
|
||||
|
||||
|
||||
def test_prepare_video_batch_clips_very_large_floats_to_uint8_max():
|
||||
video = torch.full((1, 1, 3, 2, 2), 300.0)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert int(tensor.max()) == 255
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _expand_tasks — string / list / tuple broadcasting to batch size
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_expand_tasks_string_is_broadcast_to_batch_size():
|
||||
assert _expand_tasks("pick up", batch_size=3, default=None) == ["pick up", "pick up", "pick up"]
|
||||
|
||||
|
||||
def test_expand_tasks_list_of_matching_size_passes_through():
|
||||
assert _expand_tasks(["a", "b", "c"], batch_size=3, default=None) == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_expand_tasks_tuple_is_normalised_to_list():
|
||||
assert _expand_tasks(("a", "b"), batch_size=2, default=None) == ["a", "b"]
|
||||
|
||||
|
||||
def test_expand_tasks_single_element_list_is_broadcast():
|
||||
assert _expand_tasks(["only one"], batch_size=3, default=None) == ["only one"] * 3
|
||||
|
||||
|
||||
def test_expand_tasks_size_mismatch_raises():
|
||||
with pytest.raises(ValueError, match="Expected 3 tasks"):
|
||||
_expand_tasks(["a", "b"], batch_size=3, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_missing_uses_default():
|
||||
assert _expand_tasks(None, batch_size=2, default="fallback") == ["fallback", "fallback"]
|
||||
|
||||
|
||||
def test_expand_tasks_missing_without_default_raises():
|
||||
with pytest.raises(KeyError, match="task description"):
|
||||
_expand_tasks(None, batch_size=1, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_wrong_type_raises():
|
||||
with pytest.raises(TypeError, match="must be a string or list"):
|
||||
_expand_tasks(42, batch_size=1, default=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder step — stubbed AutoProcessor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _skip_if_topreward_extras_missing(func):
|
||||
func = skip_if_package_missing("transformers")(func)
|
||||
return func
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
eos_token = "<|endoftext|>"
|
||||
pad_token = "<|endoftext|>"
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
|
||||
|
||||
|
||||
class _FakeAutoProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = _FakeTokenizer()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
def apply_chat_template(self, messages, **kwargs): # noqa: ARG002
|
||||
return "fake_prompt_text"
|
||||
|
||||
def __call__(self, text=None, images=None, videos=None, **kwargs): # noqa: ARG002
|
||||
seq_len = 10
|
||||
batch_size = len(text) if isinstance(text, list) else 1
|
||||
return {
|
||||
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
|
||||
"attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
|
||||
"pixel_values_videos": torch.zeros(batch_size, 1536, dtype=torch.float32),
|
||||
"video_grid_thw": torch.ones(batch_size, 3, dtype=torch.long),
|
||||
"mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def _build_step(monkeypatch, **overrides):
|
||||
from lerobot.rewards.topreward import processor_topreward
|
||||
|
||||
monkeypatch.setattr(processor_topreward, "AutoProcessor", _FakeAutoProcessor)
|
||||
return processor_topreward.TOPRewardEncoderProcessorStep(**overrides)
|
||||
|
||||
|
||||
def _make_transition(observation: dict, complementary: dict | None = None) -> dict:
|
||||
transition: dict = {TransitionKey.OBSERVATION: observation}
|
||||
if complementary is not None:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary
|
||||
return transition
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_emits_input_ids_and_labels(monkeypatch):
|
||||
"""The processor must emit Qwen-VL tensors including ``input_ids`` and
|
||||
``labels`` under the ``observation.topreward.*`` namespace."""
|
||||
step = _build_step(monkeypatch)
|
||||
|
||||
frames_batch = torch.zeros(2, 4, 3, 8, 8)
|
||||
out = step(
|
||||
_make_transition(
|
||||
observation={"observation.images.top": frames_batch},
|
||||
complementary={"task": ["pick", "place"]},
|
||||
)
|
||||
)
|
||||
|
||||
obs_out = out[TransitionKey.OBSERVATION]
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
assert f"{TOPREWARD_FEATURE_PREFIX}{key}" in obs_out
|
||||
|
||||
input_ids = obs_out[f"{TOPREWARD_FEATURE_PREFIX}input_ids"]
|
||||
labels = obs_out[f"{TOPREWARD_FEATURE_PREFIX}labels"]
|
||||
assert labels.dtype == torch.long
|
||||
assert labels.shape == (2, 10)
|
||||
assert labels[:, :-1].eq(-100).all()
|
||||
assert labels[:, -1].equal(input_ids[:, -1])
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_get_config_roundtrips_user_fields(monkeypatch):
|
||||
step = _build_step(
|
||||
monkeypatch,
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
image_key="observation.images.cam_top",
|
||||
task_key="task",
|
||||
default_task="do the thing",
|
||||
max_frames=8,
|
||||
fps=4.0,
|
||||
add_chat_template=True,
|
||||
max_length=2048,
|
||||
)
|
||||
|
||||
cfg = step.get_config()
|
||||
assert cfg["vlm_name"] == "Qwen/Qwen3-VL-8B-Instruct"
|
||||
assert cfg["image_key"] == "observation.images.cam_top"
|
||||
assert cfg["default_task"] == "do the thing"
|
||||
assert cfg["max_frames"] == 8
|
||||
assert cfg["fps"] == 4.0
|
||||
assert cfg["add_chat_template"] is True
|
||||
assert cfg["max_length"] == 2048
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_transform_features_is_identity(monkeypatch):
|
||||
step = _build_step(monkeypatch)
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.images.top": PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL),
|
||||
}
|
||||
}
|
||||
assert step.transform_features(features) == features
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_rejects_missing_image_key(monkeypatch):
|
||||
step = _build_step(monkeypatch, image_key="observation.images.top")
|
||||
with pytest.raises(KeyError, match="image key"):
|
||||
step(_make_transition(observation={}, complementary={"task": "pick"}))
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.parser import (
|
||||
_config_path_args,
|
||||
_config_yaml_overrides,
|
||||
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
|
||||
|
||||
|
||||
def test_extract_path_fields_from_yaml():
|
||||
"""Test that policy.path is extracted from a YAML config and removed."""
|
||||
"""Test that policy.path is extracted from a YAML config and the policy block
|
||||
is removed entirely (siblings are captured separately as cli_overrides)."""
|
||||
config = {
|
||||
"dataset": {"repo_id": "lerobot/pusht"},
|
||||
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
|
||||
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
# Path should be extracted and stored
|
||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||
|
||||
# Cleaned config should not have the path field
|
||||
# Cleaned config should not have the policy block at all -- draccus must not
|
||||
# try to decode it as PreTrainedConfig; the actual config comes from
|
||||
# from_pretrained(path) with the captured overrides applied on top.
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = yaml.safe_load(f)
|
||||
assert "path" not in cleaned["policy"]
|
||||
assert cleaned["policy"]["type"] == "smolvla"
|
||||
assert cleaned["policy"]["push_to_hub"] is False
|
||||
assert "policy" not in cleaned
|
||||
|
||||
# Original dataset should be untouched
|
||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||
|
||||
# Sibling overrides (excluding type/path) captured for from_pretrained.
|
||||
overrides = get_yaml_overrides("policy")
|
||||
assert any("push_to_hub=false" in o for o in overrides)
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
def test_extract_path_fields_from_json():
|
||||
"""Test that policy.path is extracted from a JSON config."""
|
||||
"""Test that policy.path is extracted from a JSON config and the policy
|
||||
block is removed entirely."""
|
||||
config = {
|
||||
"policy": {"type": "act", "path": "some/local/path"},
|
||||
}
|
||||
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
assert _config_path_args["policy"] == "some/local/path"
|
||||
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = json.load(f)
|
||||
assert "path" not in cleaned["policy"]
|
||||
assert "policy" not in cleaned
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
def test_extract_no_path_returns_original():
|
||||
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
|
||||
args = _flatten_to_cli_args(d)
|
||||
assert "--optimizer.use_warmup=true" in args
|
||||
assert "--optimizer.lr=0.01" in args
|
||||
|
||||
|
||||
def test_extract_removes_field_with_siblings_and_no_type():
|
||||
"""Regression: when policy.path has siblings but no type:, the entire policy
|
||||
block must still be removed from the cleaned config. Otherwise draccus tries
|
||||
to decode the leftover dict as PreTrainedConfig and crashes on the missing
|
||||
type discriminator.
|
||||
"""
|
||||
config = {
|
||||
"dataset": {"repo_id": "lerobot/pusht"},
|
||||
"policy": {
|
||||
"path": "lerobot/smolvla_base",
|
||||
"n_action_steps": 10,
|
||||
"dtype": "bfloat16",
|
||||
},
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
yaml.dump(config, f)
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = yaml.safe_load(f) or {}
|
||||
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
|
||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||
overrides = get_yaml_overrides("policy")
|
||||
assert any("n_action_steps=10" in o for o in overrides)
|
||||
assert any("dtype=bfloat16" in o for o in overrides)
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DummyNested:
|
||||
foo: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DummyConfig:
|
||||
nested: _DummyNested = field(default_factory=_DummyNested)
|
||||
other: str = "default"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls):
|
||||
return ["nested"]
|
||||
|
||||
|
||||
def test_wrap_uses_cleaned_config_for_draccus_parse():
|
||||
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
|
||||
file but must propagate that to the draccus.parse fallback branch. Without
|
||||
the fix, cli_args still contains --config_path=<original> and draccus reads
|
||||
the original YAML with `path:` still in it, crashing on the unknown field.
|
||||
"""
|
||||
config = {
|
||||
"nested": {"path": "some/checkpoint", "foo": 42},
|
||||
"other": "set-via-yaml",
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
yaml.dump(config, f)
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: _DummyConfig) -> _DummyConfig:
|
||||
captured["cfg"] = cfg
|
||||
return cfg
|
||||
|
||||
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
|
||||
main()
|
||||
|
||||
assert captured["cfg"].other == "set-via-yaml"
|
||||
assert _config_path_args["nested"] == "some/checkpoint"
|
||||
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
|
||||
# class (a real PreTrainedConfig would now load the checkpoint and apply
|
||||
# the captured yaml_overrides via from_pretrained()).
|
||||
assert captured["cfg"].nested.foo == 0
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
25
uv.lock
generated
25
uv.lock
generated
@@ -2915,6 +2915,11 @@ metaworld = [
|
||||
{ name = "scipy" },
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
]
|
||||
molmoact2 = [
|
||||
{ name = "peft" },
|
||||
{ name = "scipy" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
motorbridge-dep = [
|
||||
{ name = "motorbridge" },
|
||||
]
|
||||
@@ -3009,6 +3014,9 @@ test = [
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-timeout" },
|
||||
]
|
||||
topreward = [
|
||||
{ name = "transformers" },
|
||||
]
|
||||
training = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "av" },
|
||||
@@ -3039,6 +3047,11 @@ video-benchmark = [
|
||||
viz = [
|
||||
{ name = "rerun-sdk" },
|
||||
]
|
||||
vla-jepa = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
wallx = [
|
||||
{ name = "peft" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
@@ -3107,6 +3120,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
@@ -3128,6 +3142,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["molmoact2"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["motorbridge-dep"], marker = "extra == 'rebot'" },
|
||||
{ name = "lerobot", extras = ["motorbridge-smart-servo-dep"], marker = "extra == 'rebot'" },
|
||||
{ name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" },
|
||||
@@ -3135,6 +3150,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["peft"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'molmoact2'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["phone"], marker = "extra == 'all'" },
|
||||
@@ -3154,6 +3170,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
||||
@@ -3162,27 +3179,33 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'aloha'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'libero'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'metaworld'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'molmoact2'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'phone'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
||||
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
|
||||
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
|
||||
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
|
||||
@@ -3244,7 +3267,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user