diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 527cb7e63..1d4d9e770 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,6 +59,8 @@ title: π₀-FAST (Pi0Fast) - local: pi05 title: π₀.₅ (Pi05) + - local: molmoact2 + title: MolmoAct2 - local: eo1 title: EO-1 - local: groot diff --git a/docs/source/molmoact2.mdx b/docs/source/molmoact2.mdx new file mode 100644 index 000000000..ddd178acd --- /dev/null +++ b/docs/source/molmoact2.mdx @@ -0,0 +1,433 @@ +# MolmoAct2 Policy + +MolmoAct2 is the LeRobot policy implementation of +[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot +training, evaluation, checkpointing, and dataset interfaces for easier use with +LeRobot datasets. + +This implementation currently supports training and evaluation for the regular +MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is +not included in this LeRobot policy yet and is coming soon. + +For the original MolmoAct2 training code used for the experiments reported in +the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2). + +## Installation Requirements + +Install LeRobot with the MolmoAct2 optional dependencies: + +```bash +pip install -e ".[molmoact2]" +``` + +To run the models in this repository, you need an NVIDIA GPU. The measurements +below were taken on a single NVIDIA H100 80GB with bf16 model loading, using LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7 +padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use +`gradient_checkpointing=true` and include the forward pass, backward pass, +gradient clipping, optimizer step, and optimizer state allocation. Values are +peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for +dataloader workers, CUDA context, and fragmentation. 
+ +Multi-GPU training through `accelerate` increases throughput and global batch +size, but this LeRobot port does not currently expose the original MolmoAct2 +`fsdp_devices` model-parallel training path. The current training script has +not been tested for multi-node training. + +| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 | +| ------------------------------------------------ | ----------------: | -----------------: | -----------------: | +| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - | +| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB | +| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB | +| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB | + +The repo has been tested with Ubuntu 22.04. + +## Usage + +To use MolmoAct2 in a LeRobot training config, set: + +```python +policy.type=molmoact2 +``` + +## Training + +MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face +checkpoint format or from a checkpoint already saved by LeRobot. Both routes use +the same LeRobot training loop, dataset transforms, checkpoint saving, and +logging. The difference is only how the initial policy weights and processor +state are loaded. + +### Training With Original MolmoAct2 Weight + +Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint, +for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load +the original HF model files, then build its own policy processor from the +dataset metadata and the policy options below. + +The command below shows full fine-tuning on the merged LIBERO dataset. It uses +bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image +augmentation, and LeRobot's checkpointing/logging path. 
+ +```bash +accelerate launch \ + --num_processes=8 \ + --mixed_precision=bf16 \ + -m lerobot.scripts.lerobot_train \ + --dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.video_backend=pyav \ + --dataset.image_transforms.enable=true \ + --policy.type=molmoact2 \ + --policy.checkpoint_path=allenai/MolmoAct2-LIBERO \ + --policy.device=cuda \ + --policy.action_mode=both \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.setup_type="single franka robotic arm in libero" \ + --policy.control_mode="delta end-effector pose" \ + --policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \ + --policy.model_dtype=bfloat16 \ + --policy.num_flow_timesteps=8 \ + --policy.gradient_checkpointing=true \ + --policy.freeze_embedding=true \ + --policy.normalize_gripper=false \ + --policy.enable_knowledge_insulation=false \ + --policy.push_to_hub=false \ + --wandb.enable=true \ + --wandb.entity= \ + --wandb.project= \ + --job_name= \ + --output_dir=outputs/ \ + --steps=10000 \ + --batch_size=32 \ + --num_workers=4 \ + --log_freq=20 \ + --eval_freq=-1 \ + --save_checkpoint=true \ + --save_freq=2000 +``` + +### Training With LeRobot MolmoAct2 Weight + +Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by +LeRobot, either from a local `pretrained_model` directory or from the Hub. This +restores the saved LeRobot policy config, model weights, processor, and +normalization statistics. You can still override training-time options such as +`batch_size`, `steps`, LoRA flags, or `policy.action_mode`. 
+ +```bash +accelerate launch \ + --num_processes=8 \ + --mixed_precision=bf16 \ + -m lerobot.scripts.lerobot_train \ + --dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.video_backend=pyav \ + --dataset.image_transforms.enable=true \ + --policy.path=/path/to/pretrained_model \ + --policy.device=cuda \ + --policy.action_mode=both \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.model_dtype=bfloat16 \ + --policy.num_flow_timesteps=8 \ + --policy.gradient_checkpointing=true \ + --wandb.enable=true \ + --wandb.entity= \ + --wandb.project= \ + --job_name= \ + --output_dir=outputs/ \ + --steps=10000 \ + --batch_size=32 \ + --num_workers=4 \ + --log_freq=20 \ + --eval_freq=-1 \ + --save_checkpoint=true \ + --save_freq=2000 +``` + +### Common Practices + +For fine-tuning on a comparatively small dataset, such as a single LIBERO suite +or a real-world dataset with fewer than 200 demonstrations, a global batch size of +16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both +cases, we intentionally keep the action expert fully trainable, which we found +to be crucial for model performance. For larger fine-tuning datasets, larger +global batch sizes and full fine-tuning are usually preferred. + +### Common Policy Options + +- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from. + Use this for released MolmoAct2 weights. +- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints + created by LeRobot training. +- `policy.action_mode`: training target, one of `continuous`, `discrete`, or + `both`. `both` trains the flow-matching action expert and the discrete + action-token loss. +- `policy.train_action_expert_only`: trains only parameters whose names contain + `action_expert`. 
It requires `policy.action_mode=continuous`. +- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use + `policy.enable_lora_action_expert=true` only if LoRA should also cover action + expert linear layers. When `policy.enable_lora_action_expert=false`, the + action expert base weights remain fully trainable while the VLM is trained + through LoRA adapters. When `policy.enable_lora_action_expert=true`, the + action expert is also adapter-tuned instead of fully fine-tuned. +- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert + context K/V states before the action loss. The default is `false`. +- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use + `10`. This LeRobot port overrides the loaded checkpoint's + `max_action_horizon` with this value. +- `policy.n_action_steps`: number of actions consumed from each predicted + chunk before querying the policy again. For LIBERO, set it to `chunk_size`. +- `policy.setup_type`: text inserted into the prompt to describe the robot and + scene, e.g. `single franka robotic arm in libero`. More examples are listed + in the `metadata_by_tag` entries of + [`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json). +- `policy.control_mode`: text inserted into the prompt to describe the action + space, e.g. `delta end-effector pose` or `absolute joint pose`. +- `policy.image_keys`: ordered LeRobot image observation keys passed to the + processor. +- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`, + `bfloat16`, or `float16`. Use `bfloat16` for normal training. +- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per + example during training. We use `8` for fine-tuning. +- `policy.num_inference_steps`: optional override for continuous action + generation steps at inference time. +- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path + to reduce activation memory. 
+- `policy.freeze_embedding`: freezes input embeddings. The default is `true`. +- `policy.normalize_gripper`: controls whether gripper dimensions are included + in state/action quantile normalization. The default is `false`. +- `policy.normalize_language`: normalizes task strings before prompt + construction. The default is `true`. +- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss. + Released checkpoints use `policy.expected_max_action_dim=32`. +- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to + infer it from images, state dimension, action dimension, action horizon, and + discrete-action mode. + +### Learning Rates + +MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2 +fine-tuning experiments. + +- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM, + `policy.optimizer_vit_lr=5e-6` for the vision tower, + `policy.optimizer_connector_lr=5e-6` for image connector layers, and + `policy.optimizer_action_expert_lr=5e-5` for the action expert. +- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter + groups to `5e-5` when `policy.enable_lora_vlm=true`. By default, + `policy.enable_lora_action_expert=false`, so the action expert is still fully + fine-tuned with `policy.optimizer_action_expert_lr`. If + `policy.enable_lora_action_expert=true`, the action expert is trained through + LoRA adapters instead. +- Action-expert-only fine-tuning trains only the action expert and uses + `policy.optimizer_action_expert_lr=5e-5`. + +You can override the full fine-tuning and action-expert learning rates with +`policy.optimizer_lr`, `policy.optimizer_vit_lr`, +`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`. +Scheduler settings can be changed with `policy.scheduler_warmup_steps`, +`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`. 
+ +### Dataset Quantile Statistics + +MolmoAct2 defaults to quantile normalization for state and action features. If +your dataset has not been converted with quantile statistics, you can add them +with: + +```bash +python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ + --repo-id=your_dataset +``` + +Alternatively, train MolmoAct2 with mean/std normalization: + +```bash +--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' +``` + +## Evaluation + +Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2 +HF checkpoints. For LIBERO replication, keep the EGL rendering environment +fixed and use `policy.per_episode_seed=true`. + +**Important:** We found that `num_steps_wait=10` does not reliably let the +LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation +results reported here use `num_steps_wait=50`. + +### Evaluation With LeRobot MolmoAct2 Weight + +Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and +normalization statistics are restored together with the model. + +```bash +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +lerobot-eval \ + --policy.path=allenai/MolmoAct2-LIBERO-LeRobot \ + --policy.inference_action_mode=continuous \ + --policy.model_dtype=bfloat16 \ + --policy.use_amp=true \ + --policy.enable_inference_cuda_graph=true \ + --policy.device=cuda \ + --policy.per_episode_seed=true \ + --policy.eval_seed=1000 \ + --env.type=libero \ + --env.task=libero_10,libero_goal,libero_object,libero_spatial \ + --env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \ + --eval.batch_size=1 \ + --eval.n_episodes=50 \ + --seed=1000 +``` + +### Evaluation With Original MolmoAct2 Weight + +You can evaluate a released Hugging Face checkpoint directly without first +converting it to a LeRobot checkpoint. 
In this case, set +`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`. +For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state +normalization statistics, action horizon, prompt metadata, and image-key order +from the checkpoint's `norm_stats.json`. + +To fully replicate the MolmoAct2 paper results with released Hugging Face +checkpoints, we recommend using the v0.5.1-pinned +[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference) +branch. That branch matches the original evaluation settings used for the +reported numbers. + +```bash +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +lerobot-eval \ + --policy.type=molmoact2 \ + --policy.checkpoint_path=allenai/MolmoAct2-LIBERO \ + --policy.norm_tag=libero \ + --policy.inference_action_mode=continuous \ + --policy.model_dtype=float32 \ + --policy.use_amp=false \ + --policy.enable_inference_cuda_graph=true \ + --policy.device=cuda \ + --policy.per_episode_seed=true \ + --policy.eval_seed=1000 \ + --env.type=libero \ + --env.task=libero_goal \ + --env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \ + --eval.batch_size=1 \ + --eval.n_episodes=50 \ + --seed=1000 +``` + +Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the +full LIBERO suite. The same command works for other released MolmoAct2 +checkpoints as long as the requested `policy.norm_tag` exists in that +checkpoint's `norm_stats.json`. + +### Common Evaluation Options + +- `policy.inference_action_mode`: required for rollout. Use `continuous` for + flow-matching inference or `discrete` for action-token inference. It must be + compatible with the training-time `policy.action_mode` saved in the + checkpoint. +- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints + saved by LeRobot. 
+- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo. + Use this with `policy.type=molmoact2` and `policy.norm_tag`. +- `policy.norm_tag`: selects normalization statistics, prompt metadata, + image-key order, and action horizon from the original checkpoint's + `norm_stats.json`. It is required for direct original-HF checkpoint + evaluation. +- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal + GPU evaluation. Use `float32` only when you explicitly want fp32 inference. +- `policy.use_amp`: runs the policy forward under autocast during eval. For + `model_dtype=bfloat16`, keep this enabled. +- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA + graph path for faster repeated continuous-action rollout. +- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous + action generation deterministic per episode for replication. +- `env.task`: comma-separated LIBERO suites or a single suite. Use + `libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark. +- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected + by the policy processor. + +## Performance Results + +### LIBERO Benchmark Results + +MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To +compare and test its LeRobot implementation, we fine-tuned +[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO) +for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on +8 H100 GPUs, then compared the results to the original MolmoAct2 reference +results. + +The LeRobot fine-tuned checkpoint reported here is available at +[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot) +and was trained on +[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset). 
+ +| Benchmark | LeRobot Implementation | MolmoAct2 Original | +| -------------- | ---------------------: | -----------------: | +| LIBERO Spatial | 98.4% | 97.8% | +| LIBERO Object | 100.0% | 100.0% | +| LIBERO Goal | 98.0% | 97.8% | +| LIBERO 10 | 96.6% | 93.2% | +| Average | 98.25% | 97.20% | + +These results demonstrate MolmoAct2's strong performance across diverse robotic +manipulation tasks. To reproduce them, follow the instructions in the LIBERO +evaluation section. + +## Differences From the Original Implementation + +This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's +dataset, training, evaluation, checkpoint, and logging infrastructure. The main +differences from the original training repository are: + +- The original paper training stack loads the model in fp32 and trains under + mixed precision. This LeRobot port usually loads the checkpoint directly in + `policy.model_dtype=bfloat16` for lower memory use. +- The original repository uses its own FSDP/model-parallel training path. The + LeRobot port uses the standard LeRobot/Accelerate training path and has not + been tested for multi-node training. +- The original repository supports sequence packing. The LeRobot port trains on + one LeRobot sample per item and pads to an inferred fixed sequence budget. +- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving, + dataset transforms, image augmentation, and Weights & Biases logging + conventions. +- The original training path supports mixed action horizons by padding to + `max_action_horizon` and masking padded horizon slots in the action expert + self-attention. This is useful when training across datasets with different + control frequencies. The LeRobot port currently targets single-dataset + fine-tuning, so `policy.chunk_size` overrides the checkpoint + `max_action_horizon` and horizon masking is not implemented yet. Support for + this mixed-horizon path is planned. 
+ +## Citation + +```bibtex +@misc{fang2026molmoact2actionreasoningmodels, + title={MolmoAct2: Action Reasoning Models for Real-world Deployment}, + author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna}, + year={2026}, + eprint={2605.02881}, + archivePrefix={arXiv}, + primaryClass={cs.RO}, + url={https://arxiv.org/abs/2605.02881}, +} +``` + +## License + +This model is licensed under Apache 2.0. It is intended for research and +educational use in accordance with +[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use), +consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2). diff --git a/docs/source/policy_molmoact2_README.md b/docs/source/policy_molmoact2_README.md new file mode 100644 index 000000000..df3a6341e --- /dev/null +++ b/docs/source/policy_molmoact2_README.md @@ -0,0 +1,39 @@ +# MolmoAct2 + +This repository contains the LeRobot policy implementation of +[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for +training, evaluation, checkpointing, and dataset compatibility. + +This implementation currently supports training and evaluation for the regular +MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is +not included in this LeRobot policy yet and is coming soon. + +For the original MolmoAct2 training code used for the experiments reported in +the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2). 
+ +## LIBERO Evaluation + +Important: we found that `num_steps_wait=10` does not reliably let the LIBERO +scene stabilize and can degrade measured success. All LIBERO evaluation results +reported for this LeRobot implementation use `num_steps_wait=50`. + +## Citation + +```bibtex +@misc{fang2026molmoact2actionreasoningmodels, + title={MolmoAct2: Action Reasoning Models for Real-world Deployment}, + author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna}, + year={2026}, + eprint={2605.02881}, + archivePrefix={arXiv}, + primaryClass={cs.RO}, + url={https://arxiv.org/abs/2605.02881}, +} +``` + +## License + +This model is licensed under Apache 2.0. It is intended for research and +educational use in accordance with +[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use), +consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2). 
diff --git a/pyproject.toml b/pyproject.toml index 264297c5e..a6785c564 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,6 +198,7 @@ wallx = [ "lerobot[qwen-vl-utils-dep]", ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] +molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"] multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"] groot = [ @@ -275,6 +276,7 @@ all = [ "lerobot[multi_task_dit]", "lerobot[wallx]", "lerobot[pi]", + "lerobot[molmoact2]", "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[xvla]", @@ -405,8 +407,11 @@ default.extend-ignore-identifiers-re = [ "ein", "thw", "inpt", + "arange", + "is_compileable", "ROBOTIS", - "OT_VALUE" + "OT_VALUE", + "VanderBilt" ] # TODO: Uncomment when ready to use diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 3a6b8e5d2..68d23c9ca 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig from .groot.configuration_groot import GrootConfig as GrootConfig +from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig @@ -43,6 +44,7 @@ __all__ = [ "EO1Config", "GaussianActorConfig", "GrootConfig", + "MolmoAct2Config", "MultiTaskDiTConfig", "PI0Config", "PI0FastConfig", diff --git 
a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 8937bc6ae..05fda05d8 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig from .groot.configuration_groot import GrootConfig +from .molmoact2.configuration_molmoact2 import MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config from .pi05.configuration_pi05 import PI05Config @@ -88,7 +89,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x". + "multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x", + "molmoact2". Returns: The policy class corresponding to the given name. @@ -151,6 +153,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .eo1.modeling_eo1 import EO1Policy return EO1Policy + elif name == "molmoact2": + from .molmoact2.modeling_molmoact2 import MolmoAct2Policy + + return MolmoAct2Policy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -168,7 +174,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor", - "smolvla", "wall_x". + "smolvla", "wall_x", "molmoact2". **kwargs: Keyword arguments to be passed to the configuration class constructor. 
Returns: @@ -203,6 +209,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return WallXConfig(**kwargs) elif policy_type == "eo1": return EO1Config(**kwargs) + elif policy_type == "molmoact2": + return MolmoAct2Config(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -231,6 +239,7 @@ class ProcessorConfigKwargs(TypedDict, total=False): preprocessor_overrides: dict[str, Any] | None postprocessor_overrides: dict[str, Any] | None dataset_stats: dict[str, dict[str, torch.Tensor]] | None + dataset_meta: Any | None def make_pre_post_processors( @@ -414,6 +423,15 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, MolmoAct2Config): + from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors + + processors = make_molmoact2_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), + ) + else: try: processors = _make_processors_from_policy_config( @@ -499,6 +517,10 @@ def make_policy( action_names = ds_meta.features.get(ACTION, {}).get("names") if action_names is not None: cfg.action_feature_names = list(action_names) + if ds_meta is not None: + set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None) + if callable(set_dataset_feature_metadata): + set_dataset_feature_metadata(ds_meta.features) kwargs["config"] = cfg diff --git a/src/lerobot/policies/molmoact2/README.md b/src/lerobot/policies/molmoact2/README.md new file mode 120000 index 000000000..ef419516d --- /dev/null +++ b/src/lerobot/policies/molmoact2/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_molmoact2_README.md \ No newline at end of file diff --git a/src/lerobot/policies/molmoact2/__init__.py b/src/lerobot/policies/molmoact2/__init__.py new file mode 100644 index 000000000..bfef53bb2 --- /dev/null +++ b/src/lerobot/policies/molmoact2/__init__.py @@ -0,0 +1,21 @@ 
+#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_molmoact2 import MolmoAct2Config +from .modeling_molmoact2 import MolmoAct2Policy +from .processor_molmoact2 import make_molmoact2_pre_post_processors + +__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"] diff --git a/src/lerobot/policies/molmoact2/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/configuration_molmoact2.py new file mode 100644 index 000000000..de2585281 --- /dev/null +++ b/src/lerobot/policies/molmoact2/configuration_molmoact2.py @@ -0,0 +1,519 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import math +import os +from contextlib import suppress +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from huggingface_hub import snapshot_download + +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import ( + AdamWConfig, + CosineDecayWithWarmupSchedulerConfig, + LRSchedulerConfig, + OptimizerConfig, +) +from lerobot.utils.constants import ACTION, OBS_STATE + +from ..rtc.configuration_rtc import RTCConfig + +MOLMOACT2_DEFAULT_NUM_IMAGES = 2 +MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196 +MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80 +MOLMOACT2_TASK_TOKEN_BUDGET = 32 +MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32 +MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64 +MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4 +MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6 +MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95 + + +def _hf_token() -> str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + +def _resolve_checkpoint_location( + checkpoint_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + checkpoint_path = str(checkpoint_path or "").strip() + if not checkpoint_path: + raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") + local_path = Path(checkpoint_path).expanduser() + if local_path.exists(): + return str(local_path) + return snapshot_download( + repo_id=checkpoint_path, + repo_type="model", + revision=revision, + force_download=force_download, + ignore_patterns=["*.py", "*.pyc", "__pycache__/*"], + token=_hf_token(), + ) + + +def _load_hf_norm_metadata_for_tag( + checkpoint_path: str, + *, + revision: str | None, + force_download: bool, + norm_tag: str | None, +) -> dict[str, Any]: + norm_tag = str(norm_tag or "").strip() + if not norm_tag: + return {} + checkpoint_location = Path( + _resolve_checkpoint_location( + checkpoint_path, + 
revision=revision, + force_download=force_download, + ) + ) + norm_stats_filename = "norm_stats.json" + config_path = checkpoint_location / "config.json" + if config_path.exists(): + with suppress(OSError, json.JSONDecodeError): + norm_stats_filename = str( + json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename + ) + stats_path = checkpoint_location / norm_stats_filename + if not stats_path.exists(): + raise FileNotFoundError( + f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." + ) + payload = json.loads(stats_path.read_text()) + metadata_by_tag = payload.get("metadata_by_tag") + if not isinstance(metadata_by_tag, dict): + raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") + metadata = metadata_by_tag.get(norm_tag) + if not isinstance(metadata, dict): + available = sorted(str(tag) for tag in metadata_by_tag) + raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") + return metadata + + +@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup") +@dataclass +class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig): + """MolmoAct2-local cosine scheduler with optional decay-step auto-match. + + LeRobot's generic cosine scheduler keeps an explicit integer decay length. + For MolmoAct2, leaving num_decay_steps unset means "decay across this run's + training steps"; build() is the first point where num_training_steps is known. 
+ """ + + num_decay_steps: int | None + + def build(self, optimizer, num_training_steps: int): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.peak_lr, + decay_lr=self.decay_lr, + num_warmup_steps=self.num_warmup_steps, + num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps, + ).build(optimizer, num_training_steps=num_training_steps) + + +def _round_up(value: int, multiple: int) -> int: + return int(math.ceil(value / multiple) * multiple) + + +def infer_molmoact2_max_sequence_length( + *, + num_images: int, + state_dim: int, + action_dim: int, + action_horizon: int, + include_discrete_action: bool, +) -> int: + """Infer the padded text/image sequence cap from MolmoAct2's fixed token layout.""" + if num_images < 1: + num_images = MOLMOACT2_DEFAULT_NUM_IMAGES + if state_dim < 0: + state_dim = 0 + if action_dim < 1: + action_dim = 1 + if action_horizon < 1: + action_horizon = 1 + + image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE + prompt_tokens = ( + MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET + + MOLMOACT2_TASK_TOKEN_BUDGET + + state_dim + + MOLMOACT2_SEQUENCE_LENGTH_MARGIN + ) + action_tokens = 0 + if include_discrete_action: + action_tokens_per_step = max( + MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP, + math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM), + ) + action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step + + return _round_up( + image_tokens + prompt_tokens + action_tokens, + MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE, + ) + + +@PreTrainedConfig.register_subclass("molmoact2") +@dataclass +class MolmoAct2Config(PreTrainedConfig): + """MolmoAct2 policy backed by the converted HF checkpoint implementation.""" + + checkpoint_path: str = "allenai/MolmoAct2" + checkpoint_revision: str | None = None + checkpoint_force_download: bool = False + + n_obs_steps: int = 1 + chunk_size: int = 30 + n_action_steps: int = 30 + + action_mode: str = "both" + 
inference_action_mode: str | None = None + discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer" + discrete_generation_max_steps: int | None = None + norm_tag: str | None = None + + setup_type: str = "" + control_mode: str = "" + image_keys: list[str] = field(default_factory=list) + normalize_language: bool = True + add_setup_tokens: bool = True + add_control_tokens: bool = True + normalize_gripper: bool = False + num_state_tokens: int = 256 + # Leave unset for the default MolmoAct2 sequence budget inferred from the fixed + # image/prompt/state/action token layout. Override only for unusual long prompts. + max_sequence_length: int | None = None + + # Fixed by released MolmoAct2 checkpoints. We validate this at model load. + expected_max_action_dim: int = 32 + + # Flow-matching training knobs copied from the original MolmoAct2 training path. + num_flow_timesteps: int = 8 + flow_matching_cutoff: float = 1.0 + flow_matching_time_offset: float = 0.001 + flow_matching_time_scale: float = 0.999 + flow_matching_beta_alpha: float = 1.0 + flow_matching_beta_beta: float = 1.5 + num_inference_steps: int | None = None + mask_action_dim_padding: bool = True + enable_inference_cuda_graph: bool = True + # MolmoAct2-local eval option. When enabled, stochastic continuous action + # generation uses a rollout-local generator derived from eval_seed. + per_episode_seed: bool = False + eval_seed: int | None = None + rtc_config: RTCConfig | None = None + + # Default is full finetuning with gradients from the action expert flowing into the VLM. 
+ enable_lora_vlm: bool = False + lora_rank: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_bias: str = "none" + enable_lora_action_expert: bool = False + enable_knowledge_insulation: bool = False + freeze_embedding: bool = True + train_action_expert_only: bool = False + gradient_checkpointing: bool = False + + model_dtype: str = "bfloat16" + softmax_auxiliary_loss: bool = True + softmax_auxiliary_loss_scale: float = 1e-4 + discrete_loss_token_weighting: str = "root_subsegments_root_tokens" + + optimizer_lr: float = 1e-5 + optimizer_vit_lr: float = 5e-6 + optimizer_connector_lr: float = 5e-6 + optimizer_action_expert_lr: float = 5e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-6 + optimizer_weight_decay: float = 0.0 + optimizer_grad_clip_norm: float = 1.0 + + scheduler_warmup_steps: int = 200 + scheduler_decay_steps: int | None = None + scheduler_decay_lr: float = 1e-6 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.QUANTILES, + "ACTION": NormalizationMode.QUANTILES, + } + ) + + input_features: dict[str, PolicyFeature] = field(default_factory=dict) + output_features: dict[str, PolicyFeature] = field(default_factory=dict) + dataset_feature_names: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + super().__post_init__() + if self.action_mode not in {"continuous", "discrete", "both"}: + raise ValueError( + f"Unsupported action_mode={self.action_mode!r}. " + "Expected one of {'continuous', 'discrete', 'both'}." + ) + if self.inference_action_mode not in {None, "continuous", "discrete"}: + raise ValueError( + f"Unsupported inference_action_mode={self.inference_action_mode!r}. " + "Expected one of {None, 'continuous', 'discrete'}." 
+ ) + if self.inference_action_mode == "continuous" and self.action_mode == "discrete": + raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.") + if self.inference_action_mode == "discrete" and self.action_mode == "continuous": + raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.") + if self.train_action_expert_only and self.action_mode != "continuous": + raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.") + if self.train_action_expert_only and self.enable_lora_vlm: + raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.") + if self.enable_lora_action_expert and not self.enable_lora_vlm: + raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.") + if self.chunk_size < 1: + raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.") + if self.n_action_steps < 1: + raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.") + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})." + ) + if self.expected_max_action_dim != 32: + raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.") + if self.model_dtype not in {"float32", "bfloat16", "float16"}: + raise ValueError( + f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'." + ) + if self.lora_rank < 1: + raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.") + if self.lora_alpha < 1: + raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.") + if not 0 <= self.lora_dropout <= 1: + raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.") + if self.lora_bias not in {"none", "all", "lora_only"}: + raise ValueError( + f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'." 
+ ) + if self.discrete_loss_token_weighting not in { + "none", + "token", + "root_tokens", + "root_subsegments", + "root_subsegments_root_tokens", + }: + raise ValueError( + f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}." + ) + if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1: + raise ValueError( + f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}." + ) + if self.max_sequence_length is not None and self.max_sequence_length < 1: + raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.") + + def inferred_max_sequence_length( + self, + *, + num_images: int | None = None, + state_dim: int | None = None, + action_dim: int | None = None, + action_horizon: int | None = None, + include_discrete_action: bool | None = None, + ) -> int: + if self.max_sequence_length is not None: + return int(self.max_sequence_length) + + if num_images is None: + num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES + if state_dim is None: + state_feature = self.robot_state_feature + state_dim = int(state_feature.shape[0]) if state_feature is not None else 0 + if action_dim is None: + action_feature = self.action_feature + action_dim = ( + int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim + ) + if action_horizon is None: + action_horizon = self.chunk_size + if include_discrete_action is None: + include_discrete_action = self.action_mode in {"discrete", "both"} + + return infer_molmoact2_max_sequence_length( + num_images=int(num_images), + state_dim=int(state_dim), + action_dim=int(action_dim), + action_horizon=int(action_horizon), + include_discrete_action=bool(include_discrete_action), + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list[int]: + return 
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
    """Record the dataset's ``names`` entries for the action and state features.

    ``dataset_feature_names`` is always reset. For each of ``ACTION`` and
    ``OBS_STATE``, the feature's ``"names"`` value is stored when the feature is
    a dict and the value is not ``None``. A non-dict ``features`` leaves the
    mapping empty.
    """
    collected: dict[str, Any] = {}
    self.dataset_feature_names = collected
    tracked_keys = (ACTION, OBS_STATE)
    if not isinstance(features, dict):
        return
    for feature_key in tracked_keys:
        entry = features.get(feature_key)
        if not isinstance(entry, dict):
            continue
        if entry.get("names") is not None:
            collected[feature_key] = entry["names"]
+ ) + + if OBS_STATE not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(0,), + ) + self.input_features[OBS_STATE] = state_feature + + if ACTION not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.expected_max_action_dim,), + ) + self.output_features[ACTION] = action_feature + + def apply_norm_tag_metadata(self) -> None: + if not str(self.norm_tag or "").strip(): + return + metadata = _load_hf_norm_metadata_for_tag( + self.checkpoint_path, + revision=self.checkpoint_revision, + force_download=bool(self.checkpoint_force_download), + norm_tag=self.norm_tag, + ) + if metadata.get("action_horizon") is not None: + self.chunk_size = int(metadata["action_horizon"]) + if metadata.get("n_action_steps") is not None: + self.n_action_steps = int(metadata["n_action_steps"]) + if not self.setup_type and metadata.get("setup_type") is not None: + self.setup_type = str(metadata["setup_type"]) + if not self.control_mode and metadata.get("control_mode") is not None: + self.control_mode = str(metadata["control_mode"]) + + def saved_policy_action_mode(self) -> str | None: + pretrained_path = getattr(self, "pretrained_path", None) + if pretrained_path is None: + return None + config_path = Path(pretrained_path) / "config.json" + if not config_path.exists(): + return None + try: + mode = json.loads(config_path.read_text()).get("action_mode") + except (OSError, json.JSONDecodeError): + return None + if mode in {"continuous", "discrete", "both"}: + return str(mode) + return None + + def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str: + return saved_policy_action_mode or self.action_mode + + def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None: + requested_mode = self.inference_action_mode + if requested_mode is None: + return + training_mode = self.training_action_mode(saved_policy_action_mode) + if requested_mode == 
def validate_checkpoint_action_mode(
    self,
    checkpoint_action_mode: str,
    *,
    has_action_expert: bool,
) -> None:
    """Check that the checkpoint can serve this config's ``action_mode``.

    Raises:
        ValueError: If ``action_mode='both'`` is paired with a checkpoint that is
            not ``'both'``; if ``action_mode='discrete'`` is paired with a
            checkpoint that is neither ``'discrete'`` nor ``'both'``; or if a
            continuous-capable mode is requested without an action expert.
    """
    policy_mode = self.action_mode

    if policy_mode == "both":
        if checkpoint_action_mode != "both":
            raise ValueError(
                f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
            )
    elif policy_mode == "discrete":
        if checkpoint_action_mode not in {"discrete", "both"}:
            raise ValueError(
                f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
                f"got {checkpoint_action_mode!r}."
            )

    needs_action_expert = policy_mode in {"continuous", "both"}
    if needs_action_expert and not has_action_expert:
        raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
+ ) + if requested_mode not in {"continuous", "discrete"}: + raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.") + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.") + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.") + return requested_mode diff --git a/src/lerobot/policies/molmoact2/hf_model/__init__.py b/src/lerobot/policies/molmoact2/hf_model/__init__.py new file mode 100644 index 000000000..39b15cb3a --- /dev/null +++ b/src/lerobot/policies/molmoact2/hf_model/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa diff --git a/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py b/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py new file mode 100644 index 000000000..f7dacbce6 --- /dev/null +++ b/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import logging +import os +from pathlib import Path +from typing import ClassVar + +import numpy as np +from tokenizers import ByteLevelBPETokenizer +from tokenizers.trainers import BpeTrainer +from huggingface_hub import snapshot_download +from transformers import PreTrainedTokenizerFast +from transformers.processing_utils import ProcessorMixin + + +def _hf_token() -> str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + +def _resolve_tokenizer_location( + tokenizer_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + local_path = Path(str(tokenizer_path)).expanduser() + if local_path.exists(): + return str(local_path) + return snapshot_download( + repo_id=str(tokenizer_path), + repo_type="model", + revision=revision, + force_download=force_download, + ignore_patterns=["*.py", "*.pyc", "__pycache__/*"], + token=_hf_token(), + ) + + +class UniversalActionProcessor(ProcessorMixin): + attributes: ClassVar[list[str]] = ["tokenizer"] + tokenizer_class: str = "AutoTokenizer" + + def __init__( + self, + tokenizer: PreTrainedTokenizerFast, + scale: float = 10, + vocab_size: int = 1024, + min_token: int = 0, + *, + action_dim: int | None = None, + time_horizon: int | None = None, + ): + self.scale = scale + self.vocab_size = vocab_size + self.min_token = min_token + + # Action horizon and dimension needed during decoding. 
These can be specified + # in three ways (in order of priority): + # 1. passed in as kwargs to decode() + # 2. in the constructor + # 3. cached from the last time decode() was called + self.time_horizon = time_horizon + self.action_dim = action_dim + self.called_time_horizon = time_horizon + self.called_action_dim = action_dim + + super().__init__(tokenizer) + self.bpe_tokenizer = self.tokenizer + + def __call__(self, action_chunk: np.array) -> np.array: + from scipy.fft import dct + + assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]" + if action_chunk.ndim == 2: + action_chunk = action_chunk[None, ...] + + # Cache the time horizon and action dimension for decoding + self.called_time_horizon = action_chunk.shape[-2] + self.called_action_dim = action_chunk.shape[-1] + + dct_coeff = dct(action_chunk, axis=1, norm="ortho") + dct_coeff = np.around(dct_coeff * self.scale) + tokens = [] + for elem in dct_coeff: + token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int))) + tokens.append(self.bpe_tokenizer(token_str)["input_ids"]) + return tokens + + def decode( + self, + tokens: list[list[int]], + *, + time_horizon: int | None = None, + action_dim: int | None = None, + ) -> np.array: + from scipy.fft import idct + + self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon + self.action_dim = action_dim or self.action_dim or self.called_action_dim + + # Cache the time horizon and action dimension for the next call + self.called_time_horizon = self.time_horizon + self.called_action_dim = self.action_dim + + assert self.time_horizon is not None and self.action_dim is not None, ( + "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." 
@classmethod
def fit(
    cls,
    action_data: list[np.array],
    scale: float = 10,
    vocab_size: int = 1024,
    *,
    time_horizon: int | None = None,
    action_dim: int | None = None,
) -> "UniversalActionProcessor":
    """Train a BPE action tokenizer on DCT-quantized action chunks.

    Each array in ``action_data`` is transformed with an orthonormal DCT along
    axis 0, flattened, scaled by ``scale`` and rounded. The rounded values,
    shifted by the global minimum, are mapped to characters and used as the
    BPE training corpus.

    Args:
        action_data: Action chunks to fit on.
        scale: Multiplier applied to DCT coefficients before rounding.
        vocab_size: Target BPE vocabulary size.
        time_horizon: Stored on the returned processor for decoding.
        action_dim: Stored on the returned processor for decoding.

    Returns:
        A processor wrapping the trained tokenizer.

    Raises:
        AssertionError: If the initial alphabet alone exceeds ``vocab_size``.
    """
    from scipy.fft import dct

    # Run DCT over all inputs
    dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]

    # Quantize once and derive the token range from the same array.
    quantized = np.around(np.concatenate(dct_tokens) * scale)
    max_token = int(quantized.max())
    min_token = int(quantized.min())
    # The initial alphabet holds every integer in [min_token, max_token], inclusive.
    min_vocab_size = max_token - min_token + 1

    assert min_vocab_size <= vocab_size, (
        f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
    )
    if min_vocab_size + 100 > vocab_size:
        logging.warning(
            f"Initial alphabet size {min_vocab_size} is almost as large as the vocab "
            f"size {vocab_size}, consider increasing vocab size"
        )

    # Make token iterator for BPE training
    def _token_iter():
        for tokens in dct_tokens:
            rounded_tokens = np.around(tokens * scale) - min_token
            rounded_tokens = rounded_tokens.astype(int)
            string = "".join(map(chr, rounded_tokens))
            yield string

    # Train BPE tokenizer
    bpe = ByteLevelBPETokenizer()

    # Set up the entire range of possible tokens as the initial alphabet
    alphabet = [chr(i) for i in range(min_vocab_size)]
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        show_progress=True,
        special_tokens=[],
        initial_alphabet=alphabet,
        max_token_length=10000,
    )

    # Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
    # because it doesn't support custom alphabets)
    bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

    return cls(
        PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
        scale=scale,
        vocab_size=vocab_size,
        min_token=min_token,
        time_horizon=time_horizon,
        action_dim=action_dim,
    )
Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +""" +MolmoAct2 configuration +""" + +from typing import Optional, Any + +from transformers import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class MolmoAct2VitConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`]. + It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments, + defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
class MolmoAct2AdapterConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of MolmoAct2Adapter. Together with
    MolmoAct2VitConfig, it is used to instantiate a MolmoAct2VisionBackbone according to the specified
    arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Every constructor argument is stored on the instance under the same name.

    Example:

    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone

    >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
    >>> vit_config = MolmoAct2VitConfig()
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
    >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)

    >>> # Accessing the model configuration
    >>> vit_configuration = model.vit_config
    >>> adapter_configuration = model.adapter_config
    ```"""

    model_type = "molmoact2"
    base_config_key = "adapter_config"

    def __init__(
        self,
        vit_layers: tuple = (-3, -9),
        pooling_attention_mask: bool = False,
        hidden_size: int = 1152,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        float32_attention: bool = True,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        hidden_act: str = "silu",
        intermediate_size: int = 18944,
        text_hidden_size: int = 3584,
        image_feature_dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        # attn_implementation is set on the instance before the parent constructor
        # runs and is also forwarded to it; remaining kwargs go to PretrainedConfig.
        self.attn_implementation = attn_implementation
        super().__init__(attn_implementation=attn_implementation, **kwargs)
        self.vit_layers = vit_layers
        self.pooling_attention_mask = pooling_attention_mask
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.float32_attention = float32_attention
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.text_hidden_size = text_hidden_size
        self.image_feature_dropout = image_feature_dropout
        self.initializer_range = initializer_range
class MolmoAct2ActionExpertConfig(PretrainedConfig):
    r"""Configuration for the MolmoAct2 modern action expert.

    Every constructor argument is stored on the instance under the same name.
    ``max_action_horizon`` and ``max_action_dim`` are kept as attributes but are
    left out of the dictionary produced by :meth:`to_dict`.
    """

    model_type = "molmoact2_action_expert"
    base_config_key = "action_expert_config"

    def __init__(
        self,
        max_action_horizon: int = 32,
        max_action_dim: int = 32,
        hidden_size: int = 1024,
        num_layers: int = 32,
        num_heads: int = 16,
        mlp_ratio: float = 8.0 / 3.0,
        ffn_multiple_of: int = 256,
        timestep_embed_dim: int = 256,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        context_layer_norm: bool = True,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        rope: bool = True,
        causal_attn: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Action-chunk limits (excluded from to_dict()).
        self.max_action_horizon = max_action_horizon
        self.max_action_dim = max_action_dim

        # Transformer sizing.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.ffn_multiple_of = ffn_multiple_of
        self.timestep_embed_dim = timestep_embed_dim

        # Regularization.
        self.dropout = dropout
        self.attn_dropout = attn_dropout

        # Normalization and attention options.
        self.context_layer_norm = context_layer_norm
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps
        self.rope = rope
        self.causal_attn = causal_attn

    def to_dict(self):
        """Serialize the config, omitting the two action-limit fields."""
        serialized = super().to_dict()
        # These are derived from the parent MolmoAct2Config for HF exports. Keeping
        # them out of the public nested config avoids duplicated sources of truth.
        for derived_key in ("max_action_horizon", "max_action_dim"):
            serialized.pop(derived_key, None)
        return serialized
class MolmoAct2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
    It is used to instantiate a MolmoAct2 model according to the specified arguments, defining the model architecture.

    The config nests four sub-configs (`vit_config`, `adapter_config`, `text_config`, `action_expert_config`).
    Each may be passed as `None` (a default instance is built), as a `dict` (expanded into the sub-config's
    constructor) or as an already constructed config object (used as-is). When `add_action_expert` is falsy,
    `action_expert_config` is set to `None` whatever was passed. Otherwise the nested `max_action_dim` and
    `max_action_horizon` are always overwritten with the top-level values.

    Only `state_format="discrete"` is accepted; any other value raises `ValueError`.

    Several read-only properties (`hidden_size`, `vocab_size`, `num_attention_heads`, ...) forward to
    `text_config`, and `image_num_patch` forwards to `vit_config`.

    Example:

    ```python
    >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig

    >>> # Initializing a MolmoAct2VitConfig
    >>> vit_config = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2AdapterConfig
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2TextConfig
    >>> text_config = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2Config
    >>> configuration = MolmoAct2Config(
    >>>     vit_config=vit_config,
    >>>     adapter_config=adapter_config,
    >>>     text_config=text_config,
    >>>     image_start_token_id=151936,
    >>>     image_end_token_id=151937,
    >>>     image_patch_id=151938,
    >>>     image_col_id=151939,
    >>>     low_res_image_start_token_id=151940,
    >>>     image_low_res_id=151942,
    >>>     frame_start_token_id=151943,
    >>>     frame_end_token_id=151944,
    >>> )

    >>> # Initializing a model
    >>> model = MolmoAct2ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    # Maps each nested-config attribute to the class used to (de)serialize it.
    sub_configs = {
        "text_config": MolmoAct2TextConfig,
        "vit_config": MolmoAct2VitConfig,
        "adapter_config": MolmoAct2AdapterConfig,
        "action_expert_config": MolmoAct2ActionExpertConfig,
    }

    def __init__(
        self,
        vit_config: MolmoAct2VitConfig | dict | None = None,
        adapter_config: MolmoAct2AdapterConfig | dict | None = None,
        text_config: MolmoAct2TextConfig | dict | None = None,
        action_expert_config: MolmoAct2ActionExpertConfig | dict | None = None,
        image_start_token_id: int | None = None,
        low_res_image_start_token_id: int | None = None,
        image_end_token_id: int | None = None,
        image_low_res_id: int | None = None,
        image_patch_id: int | None = None,
        image_col_id: int | None = None,
        frame_start_token_id: int | None = None,
        frame_end_token_id: int | None = None,
        use_frame_special_tokens: bool = True,
        initializer_range: float = 0.02,
        add_action_expert: bool = True,
        max_action_dim: int = 32,
        max_action_horizon: int = 30,
        n_obs_steps: int = 30,
        action_mode: str = "both",
        state_format: str = "discrete",
        flow_matching_num_steps: int = 10,
        flow_matching_cutoff: float = 1.0,
        flow_matching_time_offset: float = 0.001,
        flow_matching_time_scale: float = 0.999,
        flow_matching_beta_alpha: float = 1.0,
        flow_matching_beta_beta: float = 1.5,
        mask_action_dim_padding: bool = True,
        enable_depth_reasoning: bool = False,
        depth_mode: int = 2,
        num_depth_codes: int = 100,
        action_expert_depth_gate: bool = False,
        action_expert_depth_gate_per_layer: bool = False,
        action_expert_depth_gate_init_bias: float = -4.0,
        action_output_token_id: int | None = None,
        action_start_token_id: int | None = None,
        action_end_token_id: int | None = None,
        action_token_start_id: int | None = None,
        num_action_tokens: int = 0,
        depth_output_token_id: int | None = None,
        depth_start_token_id: int | None = None,
        depth_end_token_id: int | None = None,
        depth_token_start_id: int | None = None,
        num_depth_tokens: int = 0,
        state_start_token_id: int | None = None,
        state_end_token_id: int | None = None,
        state_token_start_id: int | None = None,
        num_state_tokens: int = 0,
        add_setup_tokens: bool = True,
        add_control_tokens: bool = True,
        norm_stats_filename: str = "norm_stats.json",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Sub-configs: None -> defaults, dict -> expanded into the constructor, object -> used as-is.
        if vit_config is None:
            self.vit_config = MolmoAct2VitConfig()
        elif isinstance(vit_config, dict):
            self.vit_config = MolmoAct2VitConfig(**vit_config)
        else:
            self.vit_config = vit_config
        if adapter_config is None:
            self.adapter_config = MolmoAct2AdapterConfig()
        elif isinstance(adapter_config, dict):
            self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
        else:
            self.adapter_config = adapter_config
        if text_config is None:
            self.text_config = MolmoAct2TextConfig()
        elif isinstance(text_config, dict):
            self.text_config = MolmoAct2TextConfig(**text_config)
        else:
            self.text_config = text_config
        self.add_action_expert = bool(add_action_expert)
        if not self.add_action_expert:
            # Any action_expert_config that was passed is discarded in this case.
            self.action_expert_config = None
        elif action_expert_config is None:
            # The default expert gets as many layers as the text model.
            self.action_expert_config = MolmoAct2ActionExpertConfig(
                max_action_horizon=max_action_horizon,
                max_action_dim=max_action_dim,
                num_layers=self.text_config.num_hidden_layers,
            )
        elif isinstance(action_expert_config, dict):
            self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
        else:
            self.action_expert_config = action_expert_config
        if self.add_action_expert:
            # The top-level values are the source of truth; a dict produced by
            # MolmoAct2ActionExpertConfig.to_dict() does not carry these two keys.
            self.action_expert_config.max_action_dim = int(max_action_dim)
            self.action_expert_config.max_action_horizon = int(max_action_horizon)
        self._validate_release_action_config(
            state_format=state_format,
        )
        self.image_start_token_id = image_start_token_id
        self.low_res_image_start_token_id = low_res_image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.image_low_res_id = image_low_res_id
        # `image_high_res_id` mirrors `image_patch_id`; both are set from the same argument.
        self.image_high_res_id = image_patch_id
        self.image_patch_id = image_patch_id
        self.image_col_id = image_col_id
        self.frame_start_token_id = frame_start_token_id
        self.frame_end_token_id = frame_end_token_id
        self.use_frame_special_tokens = use_frame_special_tokens
        self.initializer_range = initializer_range
        self.max_action_dim = max_action_dim
        self.max_action_horizon = max_action_horizon
        self.n_obs_steps = n_obs_steps
        self.action_mode = action_mode
        self.state_format = state_format
        self.flow_matching_num_steps = flow_matching_num_steps
        self.flow_matching_cutoff = flow_matching_cutoff
        self.flow_matching_time_offset = flow_matching_time_offset
        self.flow_matching_time_scale = flow_matching_time_scale
        self.flow_matching_beta_alpha = flow_matching_beta_alpha
        self.flow_matching_beta_beta = flow_matching_beta_beta
        self.mask_action_dim_padding = mask_action_dim_padding
        self.enable_depth_reasoning = enable_depth_reasoning
        self.depth_mode = depth_mode
        self.num_depth_codes = num_depth_codes
        self.action_expert_depth_gate = action_expert_depth_gate
        self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
        self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
        self.action_output_token_id = action_output_token_id
        self.action_start_token_id = action_start_token_id
        self.action_end_token_id = action_end_token_id
        self.action_token_start_id = action_token_start_id
        self.num_action_tokens = num_action_tokens
        self.depth_output_token_id = depth_output_token_id
        self.depth_start_token_id = depth_start_token_id
        self.depth_end_token_id = depth_end_token_id
        self.depth_token_start_id = depth_token_start_id
        self.num_depth_tokens = num_depth_tokens
        self.state_start_token_id = state_start_token_id
        self.state_end_token_id = state_end_token_id
        self.state_token_start_id = state_token_start_id
        self.num_state_tokens = num_state_tokens
        self.add_setup_tokens = add_setup_tokens
        self.add_control_tokens = add_control_tokens
        self.norm_stats_filename = norm_stats_filename

    @staticmethod
    def _validate_release_action_config(
        *,
        state_format: str,
    ) -> None:
        """Reject option values this export does not support.

        Raises:
            ValueError: if `state_format` is anything other than `"discrete"`.
        """
        if state_format != "discrete":
            raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")

    @property
    def image_num_patch(self):
        """Forwarded from `vit_config.image_num_patch`."""
        assert self.vit_config is not None
        return self.vit_config.image_num_patch

    # The properties below expose text-model dimensions on the top-level config.

    @property
    def num_attention_heads(self):
        return self.text_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.text_config.num_key_value_heads

    @property
    def head_dim(self):
        return self.text_config.head_dim

    @property
    def num_hidden_layers(self):
        return self.text_config.num_hidden_layers

    @property
    def hidden_size(self):
        return self.text_config.hidden_size

    @property
    def vocab_size(self):
        return self.text_config.vocab_size

    @property
    def max_position_embeddings(self):
        return self.text_config.max_position_embeddings
+MolmoAct2ActionExpertConfig.register_for_auto_class() +MolmoAct2Config.register_for_auto_class() diff --git a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py new file mode 100644 index 000000000..a172c8477 --- /dev/null +++ b/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py @@ -0,0 +1,564 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def normalize_image(
    image: np.ndarray,
    image_mean: list[float],
    image_std: list[float],
) -> np.ndarray:
    """Return `(image - mean) / std` without modifying `image`.

    Args:
        image: array whose last axis is the channel axis (the statistics are broadcast as `[1, 1, C]`).
        image_mean: per-channel mean; a single float is also accepted and applied to every channel.
        image_std: per-channel standard deviation; a single float is also accepted.

    Returns:
        A new array. The input is never written to, so callers may keep using it. Integer inputs are
        promoted to floating point instead of raising on an in-place cast.
    """
    # Fast path for the common mean=std=0.5 case: (x - 0.5) / 0.5 == 2x - 1.
    if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
        return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
    # reshape(1, 1, -1) handles both a per-channel list and a bare scalar.
    mean = np.asarray(image_mean, dtype=np.float32).reshape(1, 1, -1)
    std = np.asarray(image_std, dtype=np.float32).reshape(1, 1, -1)
    return (image - mean) / std


def resize_image(
    image: np.ndarray,
    desired_output_size: list[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an image and rescale its values to float32 in `[0, 1]`.

    Args:
        image: float or uint8 array; its last axis is moved to the front before resizing, so it is
            presumably laid out as `[h, w, c]`.
        desired_output_size: target size handed to `torchvision.transforms.Resize`.
        resample: interpolation mode handed to `torchvision.transforms.Resize`.

    Returns:
        float32 array with the channel axis moved back to the end. Float inputs are clipped to `[0, 1]`;
        uint8 inputs are clipped to `[0, 255]` and divided by 255.

    Raises:
        AssertionError: if the input is neither floating point nor uint8.
    """
    chw = torch.from_numpy(image).permute(2, 0, 1)
    source_dtype = chw.dtype
    if torch.is_floating_point(chw):
        clip_low, clip_high = 0.0, 1.0
        in_min, in_max = 0.0, 1.0
    else:
        assert chw.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
            chw.dtype
        )
        clip_low, clip_high = 0, 255
        in_min, in_max = 0.0, 255.0

    resize = torchvision.transforms.Resize(
        desired_output_size,
        resample,
        antialias=False,
    )
    # Interpolation can overshoot the valid range; clip, then restore the source dtype.
    resized = torch.clip(resize(chw), clip_low, clip_high).to(source_dtype)

    resized = (resized.to(torch.float32) - in_min) / (in_max - in_min)
    return resized.permute(1, 2, 0).numpy()
def select_tiling(h, w, patch_size, max_num_crops):
    """Choose how to divide an image of size [h, w] into at most `max_num_crops` crops of size `patch_size`.

    Every tiling `(rows, cols)` with `rows * cols <= max_num_crops` is a candidate. For each one the scale
    factor needed to fit the image inside `rows * patch_size` by `cols * patch_size` is computed. If every
    candidate needs downscaling, the one that downscales least is chosen. Otherwise the candidate that
    needs the least upscaling (scale >= 1) is chosen. Ties go to the tiling with fewer crops, then fewer
    rows.

    Returns:
        int32 array `[rows, cols]`.

    Raises:
        ValueError: if `max_num_crops` is smaller than 1 (there would be no candidate tiling).
    """
    if max_num_crops < 1:
        raise ValueError(f"max_num_crops must be at least 1, got {max_num_crops}.")

    tilings = [
        (i, j)
        for i in range(1, max_num_crops + 1)
        for j in range(1, max_num_crops + 1)
        if i * j <= max_num_crops
    ]
    # sort so argmin and argmax favour smaller tilings in the event of a tie
    tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]

    # How much we would need to scale the image to fit exactly in each tiling
    original_size = np.array([h, w], dtype=np.float32)  # [2]

    # The original size can be zero in rare cases if the image is smaller than the margin
    # In those cases letting the scale become infinite means the tiling is based on the
    # other side, or falls back to the smallest tiling
    with np.errstate(divide="ignore"):
        required_scale_d = candidate_resolutions.astype(np.float32) / original_size
    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]
    if np.all(required_scale < 1):
        # We are forced to downscale, so try to minimize the amount of downscaling
        ix = np.argmax(required_scale)
    else:
        # Pick the resolution that required the least upscaling so that it most closely fits the image
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]


def build_resized_image(
    image: np.ndarray,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build the single resized and normalized view of an image, plus its patch index grid.

    Returns:
        resized: the image after `resize_image` and `normalize_image`, with a leading axis of size 1 added
            when the result is 3-dimensional.
        resize_idx: `[patches_h, patches_w]` array holding `0 .. patches_h * patches_w - 1` in row-major
            order, where `patches_* = base_image_input_size[*] // image_patch_size`.
    """
    resized = normalize_image(
        resize_image(image, base_image_input_size, resample),
        image_mean,
        image_std,
    )
    if resized.ndim == 3:
        resized = resized[None]
    patches_w = base_image_input_size[1] // image_patch_size
    patches_h = base_image_input_size[0] // image_patch_size
    resize_idx = np.arange(patches_w * patches_h).reshape([patches_h, patches_w])
    return resized, resize_idx
def build_overlapping_crops(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Decompose an image into a set of overlapping crops

    Args:
        image: source image; only `image.shape[:2]` is read here, the pixels go through `resize_image`.
        max_crops: upper bound on the number of crops, forwarded to `select_tiling`.
        overlap_margins: `(left_margin, right_margin)`, measured in patches. A margin of 0 is supported
            and disables masking on that side.
        base_image_input_size: `[height, width]` of one crop; must be square.
        resample, image_mean, image_std: forwarded to `resize_image` / `normalize_image`.
        image_patch_size: side length of one patch in pixels.

    :return crop_arr: [n_crops, h, w, 3] The crops
    :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
        the crops were extracted from, what patch in `crop_arr` it corresponds to

    Raises:
        AssertionError: if `base_image_input_size` is not square.
    """
    assert base_image_input_size[0] == base_image_input_size[1]
    original_image_h, original_image_w = image.shape[:2]
    crop_size = base_image_input_size[0]

    left_margin, right_margin = overlap_margins
    total_margin_pixels = image_patch_size * (right_margin + left_margin)  # pixels removed per dim
    crop_patches = base_image_input_size[0] // image_patch_size  # patches per crop dim
    crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
    crop_window_size = crop_window_patches * image_patch_size
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size

    # Decide how to tile the image, to account for the overlap margins we compute the tiling
    # as if we had an image without the margins and were using a crop size without the margins
    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src = resize_image(
        image,
        [
            tiling[0] * crop_window_size + total_margin_pixels,
            tiling[1] * crop_window_size + total_margin_pixels,
        ],
        resample,
    )
    src = normalize_image(src, image_mean, image_std)

    # Now we have to split the image into crops, and track what patches came from
    # where in `patch_idx_arr`
    n_crops = tiling[0] * tiling[1]
    patches_per_crop = crop_patch_h * crop_patch_w
    crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
    # Patch numbering inside a single crop; each crop offsets it by its position in `crop_arr`.
    base_patch_idx = np.arange(patches_per_crop).reshape(crop_patch_h, crop_patch_w)
    on_crop = 0
    for i in range(tiling[0]):
        # Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
        # which results in overlapping crop windows
        y0 = i * crop_window_size
        for j in range(tiling[1]):
            x0 = j * crop_window_size
            crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
            patch_idx = base_patch_idx + on_crop * patches_per_crop

            # Mask out idx that are in the overlap region.
            # The `> 0` guards matter for the trailing margin: `a[-0:]` is `a[0:]`, i.e. the
            # whole axis, so a zero margin would otherwise mask every patch of the crop.
            if i != 0 and left_margin > 0:
                patch_idx[:left_margin, :] = -1
            if j != 0 and left_margin > 0:
                patch_idx[:, :left_margin] = -1
            if i != tiling[0] - 1 and right_margin > 0:
                patch_idx[-right_margin:, :] = -1
            if j != tiling[1] - 1 and right_margin > 0:
                patch_idx[:, -right_margin:] = -1
            patch_idx_arr[on_crop] = patch_idx
            on_crop += 1

    # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
    # so it is ordered left-to-right order
    patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
    patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
    patch_idx_arr = np.reshape(patch_idx_arr, [-1])

    # Now get the parts not in the overlap region, so it should map each patch in `src`
    # to the correct patch it should come from in `crop_arr`
    patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
        src.shape[0] // image_patch_size,
        src.shape[1] // image_patch_size,
    )
    return crop_arr, patch_idx_arr
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Cut every image of a batch into non-overlapping square patches and flatten each patch.

    Accepts `[n_images, h, w]` or `[n_images, h, w, c]` input and returns
    `[n_images, n_patches, patch_size * patch_size (* c)]`, with patches ordered row by row.
    """
    if array.ndim == 3:
        n_images, height, width = array.shape
        channels = 1  # treat a channel-less batch as single-channel; the flattened size is unchanged
    else:
        n_images, height, width, channels = array.shape
    rows = height // patch_size
    cols = width // patch_size
    grid = array.reshape(n_images, rows, patch_size, cols, patch_size, channels)
    # Bring the two patch-grid axes next to each other ahead of the within-patch axes.
    grid = grid.transpose(0, 1, 3, 2, 4, 5)
    return grid.reshape(n_images, rows * cols, patch_size * patch_size * channels)


def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2-D index grid into `pool_h` x `pool_w` windows.

    The grid is first padded with -1, split as evenly as possible between both sides (the extra cell goes
    to the trailing side), until each dimension is a multiple of the pooling size. The result has shape
    `[h / pool_h, w / pool_w, pool_h * pool_w]`.
    """
    n_rows, n_cols = idx_arr.shape[0], idx_arr.shape[1]
    padded_rows = pool_h * ((n_rows + pool_h - 1) // pool_h)
    padded_cols = pool_w * ((n_cols + pool_w - 1) // pool_w)
    extra_rows = padded_rows - n_rows
    extra_cols = padded_cols - n_cols
    pad_widths = [
        [extra_rows // 2, (extra_rows + 1) // 2],
        [extra_cols // 2, (extra_cols + 1) // 2],
    ]
    padded = np.pad(idx_arr, pad_widths, mode="constant", constant_values=-1)
    return einops.rearrange(padded, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
    crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Turn one image into ViT patches plus the pooling layout used to build image tokens.

    With `crop_mode="resize"` only the single resized view is produced. With `"overlap-and-resize-c2"` or
    `"overlap-and-resize"` the resized view comes first, followed by the overlapping crops.

    :return image_grids, `[1, 4]` array `(resized_h, resized_w, h, w)`: the pooled grid shape of the resized
        view and of the overlapping crops (the latter is `0, 0` in "resize" mode)
    :return crops, the image crops to process with the ViT, flattened into patches
    :return pooled_patch_idx, for each pooled position, the indices of the patches in `crops` to pool,
        masked with -1

    Raises:
        ValueError: if `crop_mode` is not one of the three supported values.
    """
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    patches_w = base_image_input_size[1] // image_patch_size
    patches_h = base_image_input_size[0] // image_patch_size
    window = image_pooling_h * image_pooling_w

    def _pool(index_grid):
        # Returns (flattened pooling windows, pooled height, pooled width).
        pooled = arange_for_pooling(index_grid, image_pooling_h, image_pooling_w)
        pooled_h, pooled_w = pooled.shape[:2]
        return pooled.reshape([-1, window]), pooled_h, pooled_w

    resize_only = crop_mode == "resize"
    if not resize_only and crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
        raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")

    if not resize_only:
        crop_arr, patch_idx_arr = build_overlapping_crops(
            image,
            max_crops,
            overlap_margins,
            base_image_input_size,
            resample,
            image_mean,
            image_std,
            image_patch_size,
        )
        pooling_idx, h, w = _pool(patch_idx_arr)

    # The global (whole image) view is needed in every mode
    resized, resize_idx = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    resize_idx, resized_h, resized_w = _pool(resize_idx)

    if resize_only:
        image_grid = np.array([resized_h, resized_w, 0, 0])[None]
        return (
            image_grid,
            batch_pixels_to_patches(resized, image_patch_size),
            resize_idx,
        )

    # Global image goes first, so the order of patches in previous crops gets increased
    shifted = np.where(pooling_idx >= 0, pooling_idx + patches_h * patches_w, -1)
    all_crops = np.concatenate([resized, crop_arr], 0)
    all_pooling_idx = np.concatenate([resize_idx, shifted])
    image_grid = np.array([resized_h, resized_w, h, w])[None]

    return (image_grid, batch_pixels_to_patches(all_crops, image_patch_size), all_pooling_idx)
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
    """Extra image keyword arguments accepted by the MolmoAct2 image processor."""

    max_crops: int | None
    overlap_margins: list[int] | None
    crop_mode: str | None
    patch_size: int | None
    pooling_size: list[int] | None


class MolmoAct2ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MolmoAct2 image processor that preprocesses images for the model.

    Args:
        size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
            Size of the image after resizing.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use when resizing the image.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use when normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use when normalizing the image. This is a float or list of floats for each
            channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        max_crops (`int`, *optional*, defaults to `8`):
            Maximum number of crops to use per image.
        overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
            Overlap margins to use.
        crop_mode (`str`, *optional*, defaults to `"overlap-and-resize-c2"`):
            One of `"resize"`, `"overlap-and-resize"` or `"overlap-and-resize-c2"`.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
            The pooling size of the vision adapter, as `[pooling_h, pooling_w]`.
    """

    model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]

    def __init__(
        self,
        size: dict[str, int] | None = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool = True,
        max_crops: int = 8,
        overlap_margins: list[int] = [4, 4],
        crop_mode: str = "overlap-and-resize-c2",
        patch_size: int = 14,
        pooling_size: list[int] = [2, 2],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 378, "width": 378}
        size = get_size_dict(size, default_to_square=True)
        self.size = size

        self.resample = resample
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_convert_rgb = do_convert_rgb

        self.max_crops = max_crops
        # Copy the list arguments: the signature defaults are shared objects, so storing them
        # by reference would let a mutation on one instance leak into every later instance.
        self.overlap_margins = list(overlap_margins)
        self.crop_mode = crop_mode
        self.patch_size = patch_size
        self.pooling_size = list(pooling_size)

    def preprocess(
        self,
        images: ImageInput,
        size: dict[str, int] | None = None,
        resample: PILImageResampling | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool | None = None,
        max_crops: int | None = None,
        overlap_margins: list[int] | None = None,
        crop_mode: str | None = None,
        patch_size: int | None = None,
        pooling_size: list[int] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. When `None`, an empty `BatchFeature` is returned.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Must contain the keys `"height"` and `"width"`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB. An explicit `False` is honoured.
            max_crops (`int`, *optional*, defaults to `self.max_crops`):
                Maximum number of crops to use per image.
            overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
                Overlap margins to use.
            crop_mode (`str`, *optional*, defaults to `self.crop_mode`):
                One of `"resize"`, `"overlap-and-resize"` or `"overlap-and-resize-c2"`.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
                The pooling size of the vision adapter, as `[pooling_h, pooling_w]`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.

        Returns:
            A `BatchFeature` containing the following keys:
                - `pixel_values`: The preprocessed images, as patches of all crops of all images concatenated.
                - `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
                - `image_grids`: The image grids, one row per image.
                - `image_num_crops`: The number of crops for each image.

        Raises:
            ValueError: if `size` lacks `"height"`/`"width"`, or if `images` has an unsupported type.
        """
        if size is not None:
            if "height" not in size or "width" not in size:
                raise ValueError("size must contain 'height' and 'width' keys.")
        else:
            size = {**self.size}

        base_image_input_size = [size["height"], size["width"]]

        # These options have meaningful falsy values (`PILImageResampling.NEAREST` is 0,
        # `do_convert_rgb=False`, a scalar mean of 0.0), so fall back only on `None`.
        if resample is None:
            resample = self.resample
        if image_mean is None:
            image_mean = self.image_mean
        if image_std is None:
            image_std = self.image_std
        if do_convert_rgb is None:
            do_convert_rgb = self.do_convert_rgb

        # For the remaining options a falsy value is not usable, so `or` keeps the old fallback.
        max_crops = max_crops or self.max_crops
        overlap_margins = overlap_margins or self.overlap_margins
        crop_mode = crop_mode or self.crop_mode
        patch_size = patch_size or self.patch_size
        pooling_size = pooling_size or self.pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        if images is None:
            # Nothing to process: return an empty feature set instead of failing further down.
            return BatchFeature({}, tensor_type=return_tensors)

        images = self.fetch_images(images)
        images = make_flat_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        batch_grids = []
        batch_crops = []
        batch_pooled_patches_idx = []
        batch_num_crops = []

        for image in images:
            image_grid, crops, pooled_idx = image_to_patches_and_grids(
                image,
                max_crops,
                overlap_margins,
                base_image_input_size,
                resample,
                image_mean,
                image_std,
                patch_size,
                image_pooling_w,
                image_pooling_h,
                crop_mode,
            )
            batch_grids.append(image_grid)
            batch_crops.append(crops)
            batch_pooled_patches_idx.append(pooled_idx)
            batch_num_crops.append(crops.shape[0])

        data = {
            "pixel_values": np.concatenate(batch_crops, 0),
            "image_token_pooling": np.concatenate(batch_pooled_patches_idx, 0),
            "image_grids": np.concatenate(batch_grids, 0),
            "image_num_crops": np.array(batch_num_crops),
        }

        return BatchFeature(data, tensor_type=return_tensors)
@dataclass
class _ActionFlowInputs:
    # Inputs of one action-flow evaluation.
    # NOTE(review): the per-field notes below come from how
    # `ActionCudaGraphManager.can_use_action_flow` reads these fields - confirm against the model code.

    trajectory: torch.Tensor
    # Object exposing `kv_contexts` (iterated as (k, v) pairs), `cross_mask`, `self_mask`,
    # `valid_action` and `rope_cache`; the masks and the rope cache may be None.
    context: Any
    # One entry per step; each entry exposes a `conditioning` attribute.
    modulations: Sequence[Any]
    action_dim_is_pad: torch.Tensor | None


@dataclass
class _ActionFlowCudaGraph:
    # A CUDA graph for the action flow together with its input/output holders.
    # NOTE(review): presumably `static_inputs` are the buffers the graph was captured with and
    # `output` is the tensor the graph writes to - confirm at the capture site.

    key: tuple[Any, ...]
    graph: torch.cuda.CUDAGraph
    static_inputs: _ActionFlowInputs
    output: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphLayerStage:
    # Four tensors kept per stage of the depth-decode graph (one stage per layer, judging by the name).

    residual: torch.Tensor
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphPostStage:
    # A CUDA graph paired with the attention-context tensor it is associated with.

    graph: torch.cuda.CUDAGraph
    attn_context: torch.Tensor


@dataclass
class _DepthDecodeCudaGraph:
    # Everything belonging to one depth-decode CUDA graph: a "pre" graph, a sequence of layer
    # stages, a sequence of "post" graphs, and the tensors listed below.
    # NOTE(review): how `stages` and `post_graphs` are interleaved at replay is not visible here.

    cache_key: tuple[Any, ...]
    pre_graph: torch.cuda.CUDAGraph
    token_ids: torch.Tensor
    cos: torch.Tensor
    sin: torch.Tensor
    positions: torch.Tensor
    stages: Sequence[_DepthDecodeCudaGraphLayerStage]
    post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
    output: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphSpec:
    # Eligibility flag, cache-key prefix and three model dimensions for depth-decode graphs.

    eligible: bool
    cache_key_prefix: tuple[Any, ...]
    num_hidden_layers: int
    head_dim: int
    num_attention_heads: int
+ num_hidden_layers: int + head_dim: int + num_attention_heads: int + + +def _cache_seq_len_int(past_key_values: Cache | None) -> int: + if past_key_values is None: + return 0 + seq_len = past_key_values.get_seq_length() + if torch.is_tensor(seq_len): + return int(seq_len.item()) + return int(seq_len) + + +def _cache_max_len_int(past_key_values: Cache | None) -> int: + if past_key_values is None: + return -1 + max_len = past_key_values.get_max_cache_shape() + if torch.is_tensor(max_len): + return int(max_len.item()) + return int(max_len) + + +def _iter_cache_key_values( + past_key_values: Cache, +) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]: + layers = getattr(past_key_values, "layers", None) + if layers is not None: + for layer in layers: + yield getattr(layer, "keys", None), getattr(layer, "values", None) + return + for layer in past_key_values: + yield layer[0], layer[1] + + +class _DepthDecodeStaticLayerCache: + is_compileable = False + is_sliding = False + + def __init__(self, max_cache_len: int) -> None: + self.max_cache_len = int(max_cache_len) + self.cumulative_length = 0 + self.keys: torch.Tensor | None = None + self.values: torch.Tensor | None = None + + def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: + bsz, n_heads = key_states.shape[:2] + self.keys = torch.empty( + (bsz, n_heads, self.max_cache_len, key_states.shape[-1]), + dtype=key_states.dtype, + device=key_states.device, + ) + self.values = torch.empty( + (bsz, n_heads, self.max_cache_len, value_states.shape[-1]), + dtype=value_states.dtype, + device=value_states.device, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + *args, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.keys is None: + self._allocate(key_states, value_states) + start = self.cumulative_length + end = start + key_states.shape[-2] + if end > self.max_cache_len: + raise RuntimeError(f"KV cache length {end} exceeds 
max_cache_len={self.max_cache_len}.") + self.keys[:, :, start:end, :].copy_(key_states) + self.values[:, :, start:end, :].copy_(value_states) + self.cumulative_length = end + return self.keys[:, :, :end, :], self.values[:, :, :end, :] + + def get_seq_length(self) -> int: + return self.cumulative_length + + def get_max_cache_shape(self) -> int: + return -1 + + def reset(self) -> None: + self.cumulative_length = 0 + + +class _DepthDecodeStaticCache(Cache): + def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None: + text_config = config.get_text_config(decoder=True) + super().__init__( + layers=[ + _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len) + for _ in range(text_config.num_hidden_layers) + ] + ) + + def get_seq_length(self, layer_idx: int = 0) -> int: + return self.layers[layer_idx].get_seq_length() + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + return self.layers[layer_idx].get_max_cache_shape() + + def reset(self) -> None: + for layer in self.layers: + layer.reset() + + +class ActionCudaGraphManager: + def __init__(self, model: Any) -> None: + self.model = model + self.enabled = True + self.action_flow_graph: _ActionFlowCudaGraph | None = None + + def set_enabled(self, enabled: bool) -> None: + self.enabled = bool(enabled) + + def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool: + action_model = self.model + if not self.enabled: + return False + if action_model.training or action_model._require_action_expert().training: + return False + if inputs.trajectory.device.type != "cuda": + return False + + def all_on_cuda(): + yield inputs.trajectory + for k, v in inputs.context.kv_contexts: + yield k + yield v + for t in ( + inputs.context.cross_mask, + inputs.context.self_mask, + inputs.context.valid_action, + inputs.action_dim_is_pad, + ): + if t is not None: + yield t + if inputs.context.rope_cache is not None: + yield from inputs.context.rope_cache + for step in inputs.modulations: + yield step.conditioning 
+ for block_modulation in step.block_modulations: + yield from block_modulation + yield from step.final_modulation + + return all(t.device.type == "cuda" for t in all_on_cuda()) + + def run_action_flow( + self, + inputs: _ActionFlowInputs, + steps: int, + run_loop, + ) -> torch.Tensor: + key = _cuda_graph_key(inputs, steps) + cache = self.action_flow_graph + if cache is None or cache.key != key: + static_inputs = _clone_static_inputs(inputs) + graph, output = _capture_cuda_graph( + lambda: run_loop(static_inputs, steps), + inputs.trajectory.device, + after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory), + ) + cache = _ActionFlowCudaGraph( + key=key, + graph=graph, + static_inputs=static_inputs, + output=output, + ) + self.action_flow_graph = cache + else: + _copy_inputs_(cache.static_inputs, inputs) + + cache.graph.replay() + return cache.output.clone() + + +class DepthDecodeCudaGraphManager: + def __init__(self, model: Any) -> None: + self.model = model + self.backbone = model.model + self.enabled = True + self.graph: _DepthDecodeCudaGraph | None = None + self.graph_spec: _DepthDecodeCudaGraphSpec | None = None + + def set_enabled(self, enabled: bool) -> None: + self.enabled = bool(enabled) + + def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache: + return _DepthDecodeStaticCache( + config=self.model.config.text_config, + max_cache_len=max_cache_len, + ) + + def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec: + static = self.graph_spec + if static is None: + cfg = self.backbone.transformer.config + rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None) + static = _DepthDecodeCudaGraphSpec( + eligible=( + not cfg.norm_after + and cfg.rope_scaling_layers is None + and getattr(rotary_emb, "rope_type", None) == "default" + and cfg._attn_implementation == "sdpa" + ), + cache_key_prefix=( + cfg.hidden_size, + cfg.num_attention_heads, + cfg.num_key_value_heads, + cfg.head_dim, + cfg.num_hidden_layers, + 
cfg.use_qk_norm, + cfg.qk_norm_type, + cfg._attn_implementation, + ), + num_hidden_layers=cfg.num_hidden_layers, + head_dim=cfg.head_dim, + num_attention_heads=cfg.num_attention_heads, + ) + self.graph_spec = static + return static + + def can_use( + self, + next_input_ids: torch.Tensor, + *, + past_key_values: Cache, + attention_bias: torch.Tensor, + ) -> bool: + if not self.enabled or self.model.training or self.backbone.transformer.training: + return False + if next_input_ids.device.type != "cuda": + return False + if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1: + return False + if not isinstance(past_key_values, _DepthDecodeStaticCache): + return False + if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device: + return False + return self._depth_decode_spec().eligible + + def _depth_decode_key( + self, + next_input_ids: torch.Tensor, + attention_bias: torch.Tensor, + ) -> tuple[Any, ...]: + device = next_input_ids.device + return ( + self._depth_decode_spec().cache_key_prefix, + device.type, + device.index, + self.model.lm_head.weight.dtype, + attention_bias.shape[-1], + ) + + def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None: + emb = self.backbone.transformer.rotary_emb + cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :]) + sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :]) + + def _depth_decode_pre_layer( + self, + layer_idx: int, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + block = self.backbone.transformer.blocks[layer_idx] + attention = block.self_attn + residual = hidden_states + hidden_states = block.attn_norm(hidden_states) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, attention.head_dim) + qkv = attention.att_proj(hidden_states) + query_states, key_states, 
value_states = qkv.split(attention.fused_dims, dim=-1) + value_states = value_states.view(hidden_shape) + + apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None + norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3" + + if apply_qk_norm and not norm_after_view: + query_states = attention.q_norm(query_states) + key_states = attention.k_norm(key_states) + + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + + if norm_after_view: + query_states = attention.q_norm(query_states) + key_states = attention.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin) + return residual, query_states, key_states, value_states + + def _depth_decode_pre0( + self, + token_ids: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + inputs_embeds = self.model._embed_base_tokens(token_ids) + return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin) + + def _depth_decode_post_layer( + self, + layer_idx: int, + residual: torch.Tensor, + attn_context: torch.Tensor, + ) -> torch.Tensor: + block = self.backbone.transformer.blocks[layer_idx] + attention = block.self_attn + input_shape = residual.shape[:-1] + attn_output = attn_context.reshape(*input_shape, -1).contiguous() + attn_output = attention.attn_out(attn_output) + hidden_states = residual + block.dropout(attn_output) + + residual = hidden_states + hidden_states = block.ff_norm(hidden_states) + hidden_states = block.mlp(hidden_states) + hidden_states = residual + block.dropout(hidden_states) + return hidden_states + + def _depth_decode_post_and_pre_next( + self, + layer_idx: int, + residual: torch.Tensor, + attn_context: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> 
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context) + return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin) + + def _depth_decode_last_post( + self, + layer_idx: int, + residual: torch.Tensor, + attn_context: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context) + return self.backbone.transformer.ln_f(hidden_states) + + def _build_depth_decode_graph( + self, + next_input_ids: torch.Tensor, + *, + past_length: int, + attention_bias: torch.Tensor, + ) -> _DepthDecodeCudaGraph: + text_config = self.backbone.transformer.config + device = next_input_ids.device + dtype = self.model.lm_head.weight.dtype + static = self._depth_decode_spec() + num_layers = static.num_hidden_layers + head_dim = static.head_dim + max_cache_len = int(attention_bias.shape[-1]) + max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len) + self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len) + + token_ids = torch.empty((1, 1), device=device, dtype=torch.long) + cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype) + sin = torch.empty_like(cos) + positions = torch.arange(max_cache_len, device=device, dtype=torch.long) + context_shape = (1, 1, static.num_attention_heads, head_dim) + + token_ids.copy_(next_input_ids) + self._select_depth_decode_rope(cos, sin, past_length=past_length) + + pre_graph, pre_output = _capture_cuda_graph( + lambda: self._depth_decode_pre0(token_ids, cos, sin), + device, + ) + stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)] + post_graphs = [] + for layer_idx in range(num_layers - 1): + stage = stages[-1] + attn_context = torch.empty(context_shape, device=device, dtype=dtype) + graph, output = _capture_cuda_graph( + lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: ( + self._depth_decode_post_and_pre_next( + 
layer_idx, + stage.residual, + attn_context, + cos, + sin, + ) + ), + device, + ) + post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)) + stages.append(_DepthDecodeCudaGraphLayerStage(*output)) + + last_stage = stages[-1] + last_attn_context = torch.empty(context_shape, device=device, dtype=dtype) + last_graph, last_output = _capture_cuda_graph( + lambda: self._depth_decode_last_post( + num_layers - 1, + last_stage.residual, + last_attn_context, + ), + device, + ) + post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context)) + return _DepthDecodeCudaGraph( + cache_key=self._depth_decode_key(next_input_ids, attention_bias), + pre_graph=pre_graph, + token_ids=token_ids, + cos=cos, + sin=sin, + positions=positions, + stages=tuple(stages), + post_graphs=tuple(post_graphs), + output=last_output, + ) + + def _get_depth_decode_graph( + self, + next_input_ids: torch.Tensor, + *, + past_length: int, + attention_bias: torch.Tensor, + ) -> _DepthDecodeCudaGraph: + key = self._depth_decode_key(next_input_ids, attention_bias) + decode_graph = self.graph + if decode_graph is None or decode_graph.cache_key != key: + decode_graph = self._build_depth_decode_graph( + next_input_ids, + past_length=past_length, + attention_bias=attention_bias, + ) + self.graph = decode_graph + else: + decode_graph.token_ids.copy_(next_input_ids) + self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length) + return decode_graph + + def _run_depth_decode_attention_core( + self, + layer_idx: int, + stage: _DepthDecodeCudaGraphLayerStage, + *, + past_key_values: Cache, + attention_bias: torch.Tensor, + cache_position: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + attention = self.backbone.transformer.blocks[layer_idx].self_attn + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + 
stage.key, + stage.value, + layer_idx, + cache_kwargs, + ) + key_states = _repeat_kv(key_states, attention.num_key_value_groups) + value_states = _repeat_kv(value_states, attention.num_key_value_groups) + attn_output = F.scaled_dot_product_attention( + stage.query, + key_states, + value_states, + attn_mask=attention_bias, + dropout_p=0.0, + is_causal=False, + ) + return attn_output.transpose(1, 2) + + def run( + self, + next_input_ids: torch.Tensor, + *, + past_key_values: Cache, + attention_bias: torch.Tensor, + past_length: int, + ) -> tuple[torch.Tensor, Cache]: + end = past_length + 1 + decode_graph = self._get_depth_decode_graph( + next_input_ids, + past_length=past_length, + attention_bias=attention_bias, + ) + cache_position = decode_graph.positions[past_length:end] + attention_bias_q = attention_bias[:, :, past_length:end, :end] + + decode_graph.pre_graph.replay() + + for layer_idx, post_graph in enumerate(decode_graph.post_graphs): + attn_context = self._run_depth_decode_attention_core( + layer_idx, + decode_graph.stages[layer_idx], + past_key_values=past_key_values, + attention_bias=attention_bias_q, + cache_position=cache_position, + cos=decode_graph.cos, + sin=decode_graph.sin, + ) + post_graph.attn_context.copy_(attn_context) + post_graph.graph.replay() + + return decode_graph.output, past_key_values + + +def _cuda_graph_tensor_signature( + tensor: torch.Tensor | None, +) -> tuple[Any, ...] 
| None: + if tensor is None: + return None + return ( + tuple(tensor.shape), + tuple(tensor.stride()), + str(tensor.dtype), + str(tensor.device), + ) + + +def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]: + sig = _cuda_graph_tensor_signature + return ( + tuple((sig(k), sig(v)) for k, v in context.kv_contexts), + sig(context.cross_mask), + sig(context.self_mask), + sig(context.valid_action), + None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache), + ) + + +def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]: + sig = _cuda_graph_tensor_signature + return tuple( + ( + sig(step.conditioning), + tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations), + tuple(sig(t) for t in step.final_modulation), + ) + for step in modulations + ) + + +def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]: + sig = _cuda_graph_tensor_signature + return ( + sig(inputs.trajectory), + _cuda_graph_context_signature(inputs.context), + _cuda_graph_modulation_signature(inputs.modulations), + sig(inputs.action_dim_is_pad), + int(steps), + ) + + +def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None: + if tensor is None: + return None + static = torch.empty_strided( + tuple(tensor.shape), + tuple(tensor.stride()), + device=tensor.device, + dtype=tensor.dtype, + ) + static.copy_(tensor) + return static + + +def _clone_static_context(context: Any) -> Any: + rope_cache = None + if context.rope_cache is not None: + rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache) + return context.__class__( + kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts), + cross_mask=_clone_static_tensor(context.cross_mask), + self_mask=_clone_static_tensor(context.self_mask), + valid_action=_clone_static_tensor(context.valid_action), + rope_cache=rope_cache, + ) + + +def 
_clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]: + return tuple( + step.__class__( + conditioning=_clone_static_tensor(step.conditioning), + block_modulations=tuple( + tuple(_clone_static_tensor(t) for t in block_modulation) + for block_modulation in step.block_modulations + ), + final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation), + ) + for step in modulations + ) + + +def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs: + return _ActionFlowInputs( + trajectory=_clone_static_tensor(inputs.trajectory), + context=_clone_static_context(inputs.context), + modulations=_clone_static_modulations(inputs.modulations), + action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad), + ) + + +def _copy_context_(dst: Any, src: Any) -> None: + for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts): + dst_k.copy_(src_k) + dst_v.copy_(src_v) + if src.cross_mask is not None: + dst.cross_mask.copy_(src.cross_mask) + if src.self_mask is not None: + dst.self_mask.copy_(src.self_mask) + if src.valid_action is not None: + dst.valid_action.copy_(src.valid_action) + if src.rope_cache is not None: + for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache): + dst_tensor.copy_(src_tensor) + + +def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None: + dst.trajectory.copy_(src.trajectory) + _copy_context_(dst.context, src.context) + if src.action_dim_is_pad is not None: + dst.action_dim_is_pad.copy_(src.action_dim_is_pad) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (_rotate_half(q) * 
sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _capture_cuda_graph( + fn, + device: torch.device, + *, + after_warmup=None, +) -> tuple[torch.cuda.CUDAGraph, Any]: + warmup_stream = torch.cuda.Stream(device=device) + warmup_stream.wait_stream(torch.cuda.current_stream(device)) + with torch.cuda.stream(warmup_stream): + fn() + torch.cuda.current_stream(device).wait_stream(warmup_stream) + if after_warmup is not None: + after_warmup() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + output = fn() + return graph, output diff --git a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py new file mode 100644 index 000000000..4c36b04c8 --- /dev/null +++ b/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py @@ -0,0 +1,4591 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa + +"""Modeling code for MolmoAct2""" + +import json +import math +import os +import re +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +from collections.abc import Callable, Mapping, Sequence + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask, create_masks_for_generate +from transformers.modeling_flash_attention_utils import ( + FlashAttentionKwargs, + _flash_attention_forward, + flash_attn_supports_top_left_mask, +) +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + ModelOutput, + TransformersKwargs, + can_return_tuple, + logging, +) + +from .configuration_molmoact2 import ( + MolmoAct2ActionExpertConfig, + MolmoAct2AdapterConfig, + MolmoAct2Config, + MolmoAct2TextConfig, + MolmoAct2VitConfig, +) +from .inference import ( + ActionCudaGraphManager, + DepthDecodeCudaGraphManager, + _ActionFlowInputs, + _cache_max_len_int, + _cache_seq_len_int, + _iter_cache_key_values, +) + +logger = logging.get_logger(__name__) + + +ACTION_START_TOKEN = "" # nosec B105 +ACTION_END_TOKEN = "" # nosec B105 +ACTION_OUTPUT_TOKEN = "" # nosec B105 +STATE_START_TOKEN = "" # nosec B105 +STATE_END_TOKEN = "" # nosec B105 +STATE_TOKEN_PREFIX = " torch.Tensor: 
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def _round_up_multiple(value: int, multiple_of: int) -> int: + if multiple_of <= 0: + return value + return int(math.ceil(value / multiple_of) * multiple_of) + + +def _init_linear(linear: nn.Linear, *, zero: bool = False, scale: float = 1.0) -> None: + if zero: + nn.init.zeros_(linear.weight) + else: + nn.init.xavier_uniform_(linear.weight) + if scale != 1.0: + with torch.no_grad(): + linear.weight.mul_(scale) + if linear.bias is not None: + nn.init.zeros_(linear.bias) + + +@dataclass +class ActionExpertContext: + kv_contexts: Sequence[tuple[torch.Tensor, torch.Tensor]] + cross_mask: torch.Tensor | None + self_mask: torch.Tensor | None + valid_action: torch.Tensor | None + rope_cache: tuple[torch.Tensor, torch.Tensor] | None = None + + +@dataclass +class ActionExpertStepModulation: + conditioning: torch.Tensor + block_modulations: Sequence[tuple[torch.Tensor, ...]] + final_modulation: tuple[torch.Tensor, torch.Tensor] + + +class ActionExpertRMSNorm(nn.Module): + def __init__( + self, + size: int, + *, + eps: float = 1e-6, + elementwise_affine: bool = False, + device=None, + ) -> None: + super().__init__() + self.size = size + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(size, device=device)) + else: + self.register_parameter("weight", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.autocast(enabled=False, device_type=x.device.type): + dtype = x.dtype + x_float = x.to(torch.float32) + variance = x_float.pow(2).mean(dim=-1, keepdim=True) + out = x_float * torch.rsqrt(variance + self.eps) + out = out.to(dtype) + if self.weight is not None: + out = out * self.weight + return out + + def reset_parameters(self) -> None: + if self.weight is not None: + nn.init.ones_(self.weight) + + +class ActionExpertRotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, base: float = 10000.0) -> None: + super().__init__() + if head_dim % 2 != 0: + raise 
ValueError("RoPE requires an even head_dim.") + self.head_dim = head_dim + self.base = base + + def build_cache( + self, + *, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + half_dim = self.head_dim // 2 + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, half_dim, device=device, dtype=torch.float32) / max(half_dim, 1)) + ) + positions = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(positions, inv_freq) + cos = freqs.cos().to(dtype=dtype).view(1, 1, seq_len, half_dim) + sin = freqs.sin().to(dtype=dtype).view(1, 1, seq_len, half_dim) + return cos, sin + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + *, + rope_cache: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if rope_cache is None: + rope_cache = self.build_cache(seq_len=q.shape[-2], device=q.device, dtype=q.dtype) + cos, sin = rope_cache + half_dim = self.head_dim // 2 + + def _apply(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x[..., :half_dim], x[..., half_dim:] + return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) + + return _apply(q), _apply(k) + + +class ActionExpertSelfAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + *, + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + qk_norm: bool = True, + qk_norm_eps: float = 1e-6, + use_rope: bool = True, + ) -> None: + super().__init__() + if hidden_size % num_heads != 0: + raise ValueError("hidden_size must be divisible by num_heads") + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.attn_dropout = attn_dropout + self.q_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None + self.k_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None + self.rope = ActionExpertRotaryEmbedding(self.head_dim) if use_rope else None + self.qkv = 
nn.Linear(hidden_size, hidden_size * 3) + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.out_drop = nn.Dropout(proj_dropout) + + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.q_norm is None or self.k_norm is None: + return q, k + return self.q_norm(q), self.k_norm(k) + + def _attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + attn_mask: torch.Tensor | None = None, + is_causal: bool = False, + ) -> torch.Tensor: + dropout_p = self.attn_dropout if self.training else 0.0 + out = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + return out.transpose(1, 2).contiguous() + + def forward( + self, + x: torch.Tensor, + *, + attn_mask: torch.Tensor | None = None, + is_causal: bool = False, + rope_cache: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + bsz, seq_len, _ = x.shape + qkv = self.qkv(x).view(bsz, seq_len, 3, self.num_heads, self.head_dim) + q = qkv[:, :, 0].transpose(1, 2) + k = qkv[:, :, 1].transpose(1, 2) + v = qkv[:, :, 2].contiguous() + q, k = self._apply_qk_norm(q, k) + if self.rope is not None: + q, k = self.rope(q, k, rope_cache=rope_cache) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + out = self._attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal) + out = out.reshape(bsz, seq_len, self.hidden_size) + return self.out_drop(self.out_proj(out)) + + +class ActionExpertCrossAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + *, + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + qk_norm: bool = True, + qk_norm_eps: float = 1e-6, + ) -> None: + super().__init__() + if hidden_size % num_heads != 0: + raise ValueError("hidden_size must be divisible by num_heads") + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + 
self.attn_dropout = attn_dropout + self.q_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None + self.k_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None + self.q_proj = nn.Linear(hidden_size, hidden_size) + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.out_drop = nn.Dropout(proj_dropout) + + def _as_heads(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() == 4: + if x.shape[2] == self.num_heads: + return x + if x.shape[1] == self.num_heads: + return x.transpose(1, 2).contiguous() + raise ValueError(f"Unexpected cross-attention KV shape {tuple(x.shape)}") + if x.dim() != 3: + raise ValueError(f"Expected 3D/4D cross-attention KV, got {tuple(x.shape)}") + bsz, seq_len, _ = x.shape + return x.view(bsz, seq_len, self.num_heads, self.head_dim) + + def _attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + dropout_p = self.attn_dropout if self.training else 0.0 + out = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + return out.transpose(1, 2).contiguous() + + def forward( + self, + x: torch.Tensor, + *, + kv_k: torch.Tensor, + kv_v: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + bsz, tgt_len, _ = x.shape + q = self.q_proj(x).view(bsz, tgt_len, self.num_heads, self.head_dim) + k = self._as_heads(kv_k) + v = self._as_heads(kv_v) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + if self.q_norm is not None: + q = self.q_norm(q) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + out = self._attention(q, k, v, attn_mask=attn_mask) + out = out.reshape(bsz, tgt_len, self.hidden_size) + return self.out_drop(self.out_proj(out)) + + +class ActionExpertMLP(nn.Module): + def __init__( + self, + hidden_size: int, + *, + mlp_ratio: float, + multiple_of: int, + dropout: float = 
0.0, + ) -> None: + super().__init__() + inner_dim = _round_up_multiple(int(hidden_size * mlp_ratio), multiple_of) + self.up_proj = nn.Linear(hidden_size, inner_dim) + self.gate_proj = nn.Linear(hidden_size, inner_dim) + self.down_proj = nn.Linear(inner_dim, hidden_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + x = self.dropout(x) + x = self.down_proj(x) + return self.dropout(x) + + +class ActionExpertModulation(nn.Module): + def __init__(self, hidden_size: int, num_chunks: int) -> None: + super().__init__() + self.act = nn.SiLU() + self.linear = nn.Linear(hidden_size, num_chunks * hidden_size) + + def forward(self, conditioning: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(conditioning)) + + +class ActionExpertBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + *, + mlp_ratio: float, + ffn_multiple_of: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + qk_norm: bool = True, + qk_norm_eps: float = 1e-6, + rope: bool = True, + ) -> None: + super().__init__() + self.self_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) + self.cross_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) + self.ff_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) + self.self_attn = ActionExpertSelfAttention( + hidden_size, + num_heads, + attn_dropout=attn_dropout, + proj_dropout=dropout, + qk_norm=qk_norm, + qk_norm_eps=qk_norm_eps, + use_rope=rope, + ) + self.cross_attn = ActionExpertCrossAttention( + hidden_size, + num_heads, + attn_dropout=attn_dropout, + proj_dropout=dropout, + qk_norm=qk_norm, + qk_norm_eps=qk_norm_eps, + ) + self.mlp = ActionExpertMLP( + hidden_size, + mlp_ratio=mlp_ratio, + multiple_of=ffn_multiple_of, + dropout=dropout, + ) + self.modulation = ActionExpertModulation(hidden_size, 9) + + def forward( + self, + x: torch.Tensor, + conditioning: torch.Tensor, + *, + cross_kv: tuple[torch.Tensor, torch.Tensor], + 
class SinusoidalTimeEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of one scalar timestep per batch row.

    The output is ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = dim // 2`` geometrically spaced frequencies between 1 and 1/10000.
    When ``dim`` is odd, a single zero column is appended.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed ``timesteps``; inputs with more than one axis use the first entry of each row."""
        if timesteps.dim() > 1:
            # Collapse trailing axes and keep one scalar per batch element.
            timesteps = timesteps.view(timesteps.shape[0], -1)[:, 0]
        num_freqs = self.dim // 2
        log_step = -math.log(10000.0) / max(num_freqs - 1, 1)
        positions = torch.arange(num_freqs, device=timesteps.device, dtype=timesteps.dtype)
        freq = torch.exp(positions * log_step)
        phase = timesteps[:, None] * freq[None, :]
        embedding = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        if self.dim % 2 == 1:
            embedding = F.pad(embedding, (0, 1))
        return embedding
def reset_parameters(self) -> None:
    """Re-initialise every learnable sub-module of the action expert.

    The call order below is significant: each ``_init_linear`` call may
    consume random numbers, so reordering would change the weights obtained
    under a fixed seed.
    """
    # Timestep MLP: only its nn.Linear members are initialised here.
    for module in self.time_embed.modules():
        if isinstance(module, nn.Linear):
            _init_linear(module)
    # Input embedding of the action chunk and the shared K/V context projections.
    _init_linear(self.action_embed)
    _init_linear(self.context_k_proj)
    _init_linear(self.context_v_proj)
    # context_norm is nn.Identity when config.context_layer_norm is off.
    if isinstance(self.context_norm, ActionExpertRMSNorm):
        self.context_norm.reset_parameters()
    # Passed as `scale` to the three projections that write into the residual
    # stream (self-attn out, cross-attn out, MLP down): (2 * num_layers) ** -0.5.
    residual_scale = (2 * max(self.config.num_layers, 1)) ** -0.5
    for block in self.blocks:
        _init_linear(block.self_attn.qkv)
        _init_linear(block.self_attn.out_proj, scale=residual_scale)
        _init_linear(block.cross_attn.q_proj)
        _init_linear(block.cross_attn.out_proj, scale=residual_scale)
        _init_linear(block.mlp.up_proj)
        _init_linear(block.mlp.gate_proj)
        _init_linear(block.mlp.down_proj, scale=residual_scale)
        # NOTE(review): zero=True presumably zero-initialises the layer so the
        # modulation starts neutral - confirm against _init_linear.
        _init_linear(block.modulation.linear, zero=True)
        block.self_norm.reset_parameters()
        block.cross_norm.reset_parameters()
        block.ff_norm.reset_parameters()
        # The q/k norms are None when the block was built with qk_norm=False.
        if block.self_attn.q_norm is not None:
            block.self_attn.q_norm.reset_parameters()
        if block.self_attn.k_norm is not None:
            block.self_attn.k_norm.reset_parameters()
        if block.cross_attn.q_norm is not None:
            block.cross_attn.q_norm.reset_parameters()
        if block.cross_attn.k_norm is not None:
            block.cross_attn.k_norm.reset_parameters()
    self.final_layer.norm.reset_parameters()
    _init_linear(self.final_layer.modulation.linear, zero=True)
    _init_linear(self.final_layer.linear, zero=True)
+ def _reshape_hidden_to_heads(self, x: torch.Tensor) -> torch.Tensor: + return x.view(x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim) + + def _time_conditioning(self, timesteps: torch.Tensor) -> torch.Tensor: + conditioning = self.time_embed[0](timesteps) + first_linear = self.time_embed[1] + if isinstance(first_linear, nn.Linear): + conditioning = conditioning.to(dtype=first_linear.weight.dtype) + for module in list(self.time_embed.children())[1:]: + conditioning = module(conditioning) + return conditioning + + def _project_kv_tensor(self, x: torch.Tensor, proj: nn.Linear) -> torch.Tensor: + flat = self.context_norm(proj(x)) + return self._reshape_hidden_to_heads(flat) + + def _prepare_kv_context( + self, + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]], + ) -> Sequence[tuple[torch.Tensor, torch.Tensor]]: + if len(encoder_kv_states) != len(self.blocks): + raise ValueError( + f"Expected {len(self.blocks)} KV layers for per-layer conditioning, " + f"got {len(encoder_kv_states)}." 
+ ) + kv_contexts = [] + for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states): + k_ctx = self._project_kv_tensor(k_in, self.context_k_proj) + v_ctx = self._project_kv_tensor(v_in, self.context_v_proj) + k_norm = block.cross_attn.k_norm + if k_norm is not None: + k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2) + kv_contexts.append((k_ctx, v_ctx)) + return kv_contexts + + @staticmethod + def _build_cross_attention_mask( + encoder_attention_mask: torch.Tensor | None, + batch_size: int, + dtype: torch.dtype, + ) -> torch.Tensor | None: + if encoder_attention_mask is None: + return None + mask = encoder_attention_mask[:, None, None, :].to(dtype=dtype) + return (1.0 - mask) * torch.finfo(dtype).min + + def _build_self_attention_mask( + self, + action_attention_mask: torch.Tensor | None, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor | None: + mask = None + if action_attention_mask is not None: + valid = action_attention_mask.to(device=device, dtype=torch.bool) + key_mask = (~valid)[:, None, None, :].to(dtype=dtype) + mask = key_mask * torch.finfo(dtype).min + if self.config.causal_attn: + causal = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool).triu(diagonal=1) + causal = causal.unsqueeze(0).unsqueeze(0).to(dtype=dtype) * torch.finfo(dtype).min + mask = causal if mask is None else mask + causal + return mask + + def prepare_context( + self, + *, + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]], + encoder_attention_mask: torch.Tensor | None = None, + action_attention_mask: torch.Tensor | None = None, + state_embeddings: torch.Tensor | None = None, + batch_size: int, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> ActionExpertContext: + if state_embeddings is not None: + raise ValueError( + "MolmoAct2 HF action expert supports only discrete state tokens. " + "Continuous state embeddings are not supported." 
def prepare_modulation_cache(
    self,
    timesteps: Sequence[torch.Tensor],
) -> Sequence[ActionExpertStepModulation]:
    """Precompute the modulation tensors for every timestep in ``timesteps``.

    For each step the time conditioning is computed once, then split into the
    nine per-block modulation chunks and the two final-layer chunks. The
    result holds one ``ActionExpertStepModulation`` per input step, in order.
    """

    def _for_step(step_t: torch.Tensor) -> ActionExpertStepModulation:
        conditioning = self._time_conditioning(step_t)
        per_block = [tuple(block.modulation(conditioning).chunk(9, dim=1)) for block in self.blocks]
        final = tuple(self.final_layer.modulation(conditioning).chunk(2, dim=1))
        return ActionExpertStepModulation(
            conditioning=conditioning,
            block_modulations=per_block,
            final_modulation=final,
        )

    return [_for_step(step_t) for step_t in timesteps]
def forward_with_context(
    self,
    actions: torch.Tensor,
    timesteps: torch.Tensor,
    *,
    context: ActionExpertContext,
    modulation: ActionExpertStepModulation | None = None,
) -> torch.Tensor:
    """Run the action expert on ``actions`` using a prepared ``context``.

    Args:
        actions: 3D tensor; the second axis is the action sequence length.
        timesteps: Used to compute the time conditioning when ``modulation``
            is not supplied.
        context: Prepared KV contexts, attention masks, RoPE cache and the
            optional ``valid_action`` multiplier.
        modulation: Optional precomputed conditioning and modulation tensors;
            when given, ``timesteps`` is not used.

    Raises:
        ValueError: If the sequence length exceeds
            ``self.config.max_action_horizon``.
    """
    _, seq_len, _ = actions.shape
    if seq_len > self.config.max_action_horizon:
        raise ValueError(
            f"Action sequence length {seq_len} exceeds configured max_action_horizon={self.config.max_action_horizon}"
        )

    if modulation is not None:
        conditioning = modulation.conditioning
        block_modulations = modulation.block_modulations
        final_modulation = modulation.final_modulation
    else:
        conditioning = self._time_conditioning(timesteps)
        block_modulations = [None] * len(self.blocks)
        final_modulation = None

    keep = context.valid_action

    def _zero_padded(tensor: torch.Tensor) -> torch.Tensor:
        # Padded action slots are multiplied by zero after every stage.
        return tensor if keep is None else tensor * keep

    x = _zero_padded(self.action_embed(actions))
    for block, kv_context, block_modulation in zip(self.blocks, context.kv_contexts, block_modulations):
        x = _zero_padded(
            block(
                x,
                conditioning,
                cross_kv=kv_context,
                self_attn_mask=context.self_mask,
                attn_mask=context.cross_mask,
                is_causal=self.config.causal_attn,
                modulation=block_modulation,
                rope_cache=context.rope_cache,
            )
        )
    return _zero_padded(self.final_layer(x, conditioning, modulation=final_modulation))
Sequence[tuple[torch.Tensor, torch.Tensor]], + encoder_attention_mask: torch.Tensor | None = None, + action_attention_mask: torch.Tensor | None = None, + state_embeddings: torch.Tensor | None = None, + ) -> torch.Tensor: + bsz, seq_len, _ = actions.shape + context = self.prepare_context( + encoder_kv_states=encoder_kv_states, + encoder_attention_mask=encoder_attention_mask, + action_attention_mask=action_attention_mask, + state_embeddings=state_embeddings, + batch_size=bsz, + seq_len=seq_len, + device=actions.device, + dtype=actions.dtype, + ) + return self.forward_with_context(actions, timesteps, context=context) + + +def _to_numpy(value: Any) -> np.ndarray: + if isinstance(value, np.ndarray): + return value + if torch.is_tensor(value): + return value.detach().cpu().numpy() + return np.asarray(value) + + +def _to_array(value: Any) -> np.ndarray | None: + if value is None: + return None + if torch.is_tensor(value): + tensor = value.detach() + if tensor.dtype in (torch.bfloat16, torch.float16): + tensor = tensor.float() + return tensor.cpu().numpy().astype(np.float32, copy=False) + return np.asarray(value, dtype=np.float32) + + +def _to_mask(value: Any, fallback_like: np.ndarray | None) -> np.ndarray | None: + if value is None: + return None + mask = np.asarray(value, dtype=np.bool_) + if fallback_like is not None and mask.shape != fallback_like.shape: + mask = np.broadcast_to(mask, fallback_like.shape) + return mask + + +def _feature_dim_from_stats(stats: Mapping[str, Any] | None) -> int | None: + if not isinstance(stats, Mapping): + return None + for key in ( + "mean", + "std", + "min", + "max", + "q01", + "q99", + "q10", + "q90", + "mask", + "names", + ): + value = stats.get(key) + if value is None: + continue + arr = np.asarray(value) + if arr.shape: + return int(arr.shape[-1]) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return int(len(value)) + return None + + +class _FeatureNormalizer: + def __init__( + self, + *, + mode: str, + 
mean: np.ndarray | None = None, + std: np.ndarray | None = None, + min_val: np.ndarray | None = None, + max_val: np.ndarray | None = None, + q_low: np.ndarray | None = None, + q_high: np.ndarray | None = None, + mask: np.ndarray | None = None, + zero_mask: np.ndarray | None = None, + ): + self.mode = mode + self.mean = mean + self.std = std + self.min_val = min_val + self.max_val = max_val + self.q_low = q_low + self.q_high = q_high + self.mask = mask + self.zero_mask = zero_mask + + @classmethod + def from_stats(cls, stats: Mapping[str, Any] | None, mode: str) -> Optional["_FeatureNormalizer"]: + if stats is None: + return None + raw_mask = stats.get("mask") if isinstance(stats, Mapping) else None + if mode == "none": + fallback = None + for key in ( + "mean", + "std", + "min", + "max", + "q01", + "q99", + "q10", + "q90", + "mask", + ): + fallback = _to_array(stats.get(key)) + if fallback is not None: + break + return cls(mode=mode, mask=_to_mask(raw_mask, fallback)) + if mode == "mean_std": + mean = _to_array(stats.get("mean")) + std = _to_array(stats.get("std")) + if mean is None or std is None: + raise ValueError("norm_mode='mean_std' requires mean and std stats.") + return cls(mode=mode, mean=mean, std=std, mask=_to_mask(raw_mask, mean)) + if mode == "min_max": + min_val = _to_array(stats.get("min")) + max_val = _to_array(stats.get("max")) + if min_val is None or max_val is None: + raise ValueError("norm_mode='min_max' requires min and max stats.") + return cls( + mode=mode, + min_val=min_val, + max_val=max_val, + mask=_to_mask(raw_mask, min_val), + zero_mask=(min_val == max_val), + ) + if mode in {"q01_q99", "q10_q90"}: + low_key, high_key = ("q01", "q99") if mode == "q01_q99" else ("q10", "q90") + q_low = _to_array(stats.get(low_key)) + q_high = _to_array(stats.get(high_key)) + if q_low is None or q_high is None: + raise ValueError(f"norm_mode={mode!r} requires {low_key} and {high_key} stats.") + min_val = _to_array(stats.get("min")) + max_val = 
_to_array(stats.get("max")) + fallback = min_val if min_val is not None else q_low + zero_mask = None if min_val is None or max_val is None else (min_val == max_val) + return cls( + mode=mode, + min_val=min_val, + max_val=max_val, + q_low=q_low, + q_high=q_high, + mask=_to_mask(raw_mask, fallback), + zero_mask=zero_mask, + ) + raise ValueError(f"Unsupported robot normalization mode {mode!r}.") + + def normalize(self, x: Any) -> Any: + arr = _to_array(x) + if arr is None: + return None + eps = 1e-6 + if self.mode == "none": + normed = arr + elif self.mode == "mean_std": + normed = (arr - self.mean) / np.maximum(self.std, eps) + elif self.mode == "min_max": + normed = 2.0 * (arr - self.min_val) / np.maximum(self.max_val - self.min_val, eps) - 1.0 + elif self.mode in {"q01_q99", "q10_q90"}: + normed = 2.0 * (arr - self.q_low) / np.maximum(self.q_high - self.q_low, eps) - 1.0 + else: + normed = arr + if self.mode in {"min_max", "q01_q99", "q10_q90"}: + normed = np.clip(normed, -1.0, 1.0) + if self.mask is not None: + normed = np.where(self.mask, normed, arr) + if self.zero_mask is not None: + normed = np.where(self.zero_mask, 0.0, normed) + if torch.is_tensor(x): + return torch.as_tensor(normed, device=x.device, dtype=x.dtype) + return normed + + def unnormalize(self, x: Any) -> Any: + arr = _to_array(x) + if arr is None: + return None + if self.mode in {"min_max", "q01_q99", "q10_q90"}: + arr = np.clip(arr, -1.0, 1.0) + if self.mode == "none": + out = arr + elif self.mode == "mean_std": + out = arr * self.std + self.mean + elif self.mode == "min_max": + out = (arr + 1.0) * (self.max_val - self.min_val) / 2.0 + self.min_val + elif self.mode in {"q01_q99", "q10_q90"}: + out = (arr + 1.0) * (self.q_high - self.q_low) / 2.0 + self.q_low + else: + out = arr + if self.mask is not None: + out = np.where(self.mask, out, arr) + if torch.is_tensor(x): + return torch.as_tensor(out, device=x.device, dtype=x.dtype) + return out + + +class _RobotStats: + def __init__(self, payload: 
Mapping[str, Any]): + self.norm_mode = str(payload.get("norm_mode", "min_max")) + self.metadata_by_tag: dict[str, dict[str, Any]] = { + str(tag): dict(metadata or {}) + for tag, metadata in dict(payload.get("metadata_by_tag") or {}).items() + } + self.action_normalizers = {} + self.state_normalizers = {} + for tag, metadata in self.metadata_by_tag.items(): + if metadata.get("action_stats") is not None: + self.action_normalizers[tag] = _FeatureNormalizer.from_stats( + metadata.get("action_stats"), + self.norm_mode, + ) + if metadata.get("state_stats") is not None: + self.state_normalizers[tag] = _FeatureNormalizer.from_stats( + metadata.get("state_stats"), + self.norm_mode, + ) + + def validate_tag(self, norm_tag: str | None) -> str: + tag = str(norm_tag or "").strip() + if not tag: + raise ValueError("MolmoAct2 `predict_action` requires `norm_tag`.") + if tag not in self.metadata_by_tag: + allowed = ", ".join(sorted(self.metadata_by_tag)) + raise ValueError(f"Unknown MolmoAct2 normalization tag {tag!r}. 
Allowed tags: {allowed}.") + return tag + + def get_metadata(self, norm_tag: str | None) -> dict[str, Any]: + if norm_tag is None: + return {} + return dict(self.metadata_by_tag.get(str(norm_tag), {}) or {}) + + def normalize_state(self, state: Any, norm_tag: str) -> Any: + normalizer = self.state_normalizers.get(str(norm_tag)) + return state if normalizer is None else normalizer.normalize(state) + + def unnormalize_action(self, action: Any, norm_tag: str) -> Any: + normalizer = self.action_normalizers.get(str(norm_tag)) + return action if normalizer is None else normalizer.unnormalize(action) + + def get_action_dim(self, norm_tag: str) -> int | None: + metadata = self.get_metadata(norm_tag) + stats = metadata.get("action_stats") + dim = _feature_dim_from_stats(stats) + return dim + + def get_state_dim(self, norm_tag: str) -> int | None: + metadata = self.get_metadata(norm_tag) + return _feature_dim_from_stats(metadata.get("state_stats")) + + def get_action_horizon(self, norm_tag: str) -> int | None: + return self._get_positive_int(norm_tag, "action_horizon") + + def get_n_action_steps(self, norm_tag: str) -> int | None: + return self._get_positive_int(norm_tag, "n_action_steps") + + def _get_positive_int(self, norm_tag: str, key: str) -> int | None: + value = self.get_metadata(norm_tag).get(key) + if value is None: + return None + value = int(value) + if value < 1: + raise ValueError(f"Robot metadata for norm_tag={norm_tag!r} must define {key} >= 1.") + return value + + +def _normalize_image_for_cache(image: Any) -> np.ndarray: + arr = np.asarray(image) + if arr.ndim == 2: + arr = np.stack([arr] * 3, axis=-1) + if arr.ndim == 3 and arr.shape[0] in {1, 3, 4} and arr.shape[-1] not in {1, 3, 4}: + arr = np.moveaxis(arr, 0, -1) + if arr.ndim == 3 and arr.shape[-1] == 1: + arr = np.repeat(arr, 3, axis=-1) + if arr.dtype in (np.float32, np.float64): + if arr.size > 0 and float(arr.max()) <= 1.0: + arr = arr * 255.0 + arr = np.clip(arr, 0, 255).astype(np.uint8) + elif 
arr.dtype != np.uint8: + arr = np.clip(arr, 0, 255).astype(np.uint8) + return arr + + +def _extract_first_image(images: Any) -> np.ndarray | None: + if images is None: + return None + if isinstance(images, (list, tuple)): + if not images: + return None + return _normalize_image_for_cache(images[0]) + arr = _to_numpy(images) + if arr.ndim == 4: + return _normalize_image_for_cache(arr[0]) + return _normalize_image_for_cache(arr) + + +def _resize_depth_reasoning_image(image: np.ndarray, target_size: int) -> np.ndarray: + from PIL import Image + + if image.shape[0] == target_size and image.shape[1] == target_size: + return image + pil_image = Image.fromarray(np.asarray(image, dtype=np.uint8)) + return np.asarray(pil_image.resize((target_size, target_size), Image.BILINEAR)) + + +def _compute_depth_update_mask( + current_image: np.ndarray, + previous_image: np.ndarray, + *, + num_depth_codes: int, +) -> np.ndarray: + grid_side = int(math.isqrt(int(num_depth_codes))) + if grid_side * grid_side != int(num_depth_codes): + raise ValueError( + f"enable_adaptive_depth=True requires a square depth grid, got num_depth_codes={int(num_depth_codes)}." 
+ ) + target_size = grid_side * _DEPTH_REASONING_PATCH_SIZE + current_resized = _resize_depth_reasoning_image(current_image, target_size).astype(np.float32) + previous_resized = _resize_depth_reasoning_image(previous_image, target_size).astype(np.float32) + current_patches = ( + current_resized.reshape( + grid_side, + _DEPTH_REASONING_PATCH_SIZE, + grid_side, + _DEPTH_REASONING_PATCH_SIZE, + 3, + ) + .transpose(0, 2, 1, 3, 4) + .reshape(grid_side, grid_side, -1) + ) + previous_patches = ( + previous_resized.reshape( + grid_side, + _DEPTH_REASONING_PATCH_SIZE, + grid_side, + _DEPTH_REASONING_PATCH_SIZE, + 3, + ) + .transpose(0, 2, 1, 3, 4) + .reshape(grid_side, grid_side, -1) + ) + dot = np.sum(current_patches * previous_patches, axis=-1) + norm_current = np.linalg.norm(current_patches, axis=-1) + norm_previous = np.linalg.norm(previous_patches, axis=-1) + denom = norm_current * norm_previous + similarity = np.where(denom < 1e-8, 1.0, dot / (denom + 1e-12)) + return np.asarray(similarity < _DEPTH_REASONING_THRESHOLD, dtype=np.bool_).reshape(-1) + + +def _build_depth_update_spans( + update_mask: Sequence[bool], +) -> list[tuple[int, int, bool]]: + flat_mask = np.asarray(update_mask, dtype=np.bool_).reshape(-1) + if flat_mask.size == 0: + return [] + spans: list[tuple[int, int, bool]] = [] + start = 0 + current_value = bool(flat_mask[0]) + for idx in range(1, int(flat_mask.shape[0])): + next_value = bool(flat_mask[idx]) + if next_value == current_value: + continue + spans.append((start, idx, current_value)) + start = idx + current_value = next_value + spans.append((start, int(flat_mask.shape[0]), current_value)) + return spans + + +def _wrap_setup_text(setup_type: str, add_setup_tokens: bool = False) -> str: + setup_type = str(setup_type or "") + if setup_type.startswith(SETUP_START_TOKEN) and setup_type.endswith(SETUP_END_TOKEN): + return setup_type + if not setup_type or not add_setup_tokens: + return setup_type + return 
f"{SETUP_START_TOKEN}{setup_type}{SETUP_END_TOKEN}" + + +def _wrap_control_text(control_mode: str, add_control_tokens: bool = False) -> str: + control_mode = str(control_mode or "") + if control_mode.startswith(CONTROL_START_TOKEN) and control_mode.endswith(CONTROL_END_TOKEN): + return control_mode + if not control_mode or not add_control_tokens: + return control_mode + return f"{CONTROL_START_TOKEN}{control_mode}{CONTROL_END_TOKEN}" + + +def _discretize_normalized_state(state: np.ndarray, num_state_tokens: int) -> np.ndarray: + arr = np.asarray(state, dtype=np.float32) + arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0) + arr = np.clip(arr, -1.0, 1.0) + scaled = (arr + 1.0) / 2.0 * float(num_state_tokens - 1) + return np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1) + + +def _build_discrete_state_string(state: np.ndarray | None, num_state_tokens: int) -> str: + if state is None: + return "" + token_ids = _discretize_normalized_state(state, num_state_tokens).reshape(-1) + return f"{STATE_START_TOKEN}{''.join(f'{STATE_TOKEN_PREFIX}{int(token_id)}>' for token_id in token_ids)}{STATE_END_TOKEN}" + + +def _normalize_question_text(text: str) -> str: + normalized = re.sub(r"\s+", " ", text).strip() + if not normalized: + return "" + previous = None + while normalized and normalized != previous: + previous = normalized + normalized = normalized.strip().strip(_QUESTION_SURROUNDING_DELIMITERS).strip() + for pattern in _QUESTION_PREFIX_PATTERNS: + normalized = pattern.sub("", normalized, count=1).strip() + normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() + normalized = normalized.rstrip(_QUESTION_TRAILING_CLOSERS).rstrip() + normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() + sentence_chunks = [chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip()] + if len(sentence_chunks) > 1: + normalized = "; ".join(sentence_chunks) + normalized = normalized.lower() + 
return normalized + + +def _build_robot_text( + *, + task: str, + style: str, + discrete_state_string: str, + setup_type: str, + control_mode: str, + add_setup_tokens: bool, + add_control_tokens: bool, + num_images: int, +) -> str: + setup_text = _wrap_setup_text(setup_type, add_setup_tokens=add_setup_tokens) + control_text = _wrap_control_text(control_mode, add_control_tokens=add_control_tokens) + state_clause = ( + f" The current state of the robot is {discrete_state_string}." if discrete_state_string else "" + ) + if style == "robot_depth_action": + prompt = ( + f"The task is to {task}. The setup is {setup_text}.{state_clause} " + f"The expected control mode is {control_text}. Given these, first predict the depth map of the main image " + "and then predict the action the robot should take to complete the task?" + ) + trigger = f"{DEPTH_OUTPUT_TOKEN}{ACTION_OUTPUT_TOKEN}" + else: + prompt = ( + f"The task is to {task}. The setup is {setup_text}.{state_clause} " + f"The expected control mode is {control_text}. Given these, what action should the robot take to complete the task?" 
+ ) + trigger = ACTION_OUTPUT_TOKEN + if num_images <= 0: + image_prefix = "" + elif num_images == 1: + image_prefix = "<|image|>" + else: + image_prefix = "".join(f"Image {idx + 1}<|image|>" for idx in range(num_images)) + return f"{image_prefix}<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{trigger}" + + +def _flatten_generated_token_ids(token_ids: torch.Tensor) -> list[int]: + if token_ids.ndim == 3: + return [int(x) for x in token_ids[0, 0].detach().cpu().tolist()] + if token_ids.ndim == 2: + return [int(x) for x in token_ids[0].detach().cpu().tolist()] + if token_ids.ndim == 1: + return [int(x) for x in token_ids.detach().cpu().tolist()] + raise ValueError(f"Unexpected generated token tensor shape {tuple(token_ids.shape)}") + + +def _extract_discrete_token_bins( + generated_ids: list[int], + start_token_id: int, + end_token_id: int, + token_id_to_bin: dict[int, int], +) -> list[int]: + start_idx = None + end_idx = None + for idx, token_id in enumerate(generated_ids): + if token_id == start_token_id: + start_idx = idx + break + if start_idx is not None: + for idx in range(start_idx + 1, len(generated_ids)): + if generated_ids[idx] == end_token_id: + end_idx = idx + break + span_start = 0 if start_idx is None else start_idx + 1 + span_end = len(generated_ids) if end_idx is None else end_idx + return [ + int(token_id_to_bin[token_id]) + for token_id in generated_ids[span_start:span_end] + if token_id in token_id_to_bin + ] + + +@dataclass +class MolmoAct2ActionOutput(ModelOutput): + actions: torch.FloatTensor | None = None + generated_token_ids: torch.LongTensor | None = None + depth_bins: torch.LongTensor | None = None + depth_cache: dict[str, Any] | None = None + + +@dataclass +class _DepthPrefix: + token_ids: torch.Tensor + depth_bins: torch.Tensor + full_input_ids: torch.Tensor + attention_mask: torch.Tensor | None + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] + next_output: Any + past_key_values: Cache | None + + +@dataclass 
+class MolmoAct2CausalLMOutputWithPast(ModelOutput): + """ + Base class for MolmoAct2 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + +@dataclass +class MolmoAct2ModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for MolmoAct2 outputs, with hidden states and attentions. + + Args: + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_num_patches, hidden_size)`. 
+ image_hidden_states of the model produced by the vision backbone + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + +class ViTMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + hidden_act: str, + device: str | torch.device = None, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device) + self.act = ACT2FN[hidden_act] + self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(self.act(self.w1(x))) + + +class ViTMultiHeadDotProductAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_key_value_heads: int, + head_dim: int, + use_bias: bool = True, + input_dim: int | None = None, + float32_attention: bool = True, + attention_dropout: float = 0.0, + residual_dropout: float = 0.0, + device: str | torch.device = None, + attn_implementation: str = "eager", + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.attn_implementation = attn_implementation + self.is_causal = False + + input_dim = input_dim or hidden_size + + self.wq = nn.Linear( + input_dim, + self.num_heads * self.head_dim, + bias=use_bias, + device=device, + ) + self.wk = nn.Linear( + input_dim, + self.num_key_value_heads * self.head_dim, + bias=use_bias, + device=device, + ) + self.wv = nn.Linear( + input_dim, + self.num_key_value_heads * self.head_dim, + bias=use_bias, + device=device, + ) + self.wo = nn.Linear( + self.num_heads * self.head_dim, + self.hidden_size, + ) + self.float32_attention = float32_attention + 
self.attention_dropout = attention_dropout + self.residual_dropout = nn.Dropout(residual_dropout) + self.sdpa_backend_list = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + + def _split_heads(self, hidden_states, num_heads) -> torch.Tensor: + return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states) -> torch.Tensor: + return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) + + def forward( + self, + inputs_q: torch.Tensor, + inputs_kv: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if inputs_kv is not None: + inputs_k = inputs_kv + inputs_v = inputs_kv + else: + inputs_k = inputs_q + inputs_v = inputs_q + + xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v) + + xq = self._split_heads(xq, self.num_heads) + xk = self._split_heads(xk, self.num_key_value_heads) + xv = self._split_heads(xv, self.num_key_value_heads) + + if self.num_heads != self.num_key_value_heads: + xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) + xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) + + og_dtype = xq.dtype + + if self.float32_attention: + xq = xq.to(torch.float) + xk = xk.to(torch.float) + + dropout_p = 0.0 if not self.training else self.attention_dropout + + if self.attn_implementation == "eager": + attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype) + attn_weights = F.dropout(attn_weights, p=dropout_p, training=self.training) + attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv) + + elif self.attn_implementation == "sdpa": + if self.float32_attention: + xv = xv.to(torch.float32) + + query = xq.transpose(1, 2).contiguous() + key = 
xk.transpose(1, 2).contiguous() + value = xv.transpose(1, 2).contiguous() + if inputs_kv is not None: + with sdpa_kernel(self.sdpa_backend_list): + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + dropout_p=dropout_p, + ).transpose(1, 2) + else: + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + dropout_p=dropout_p, + ).transpose(1, 2) + + elif self.attn_implementation == "flash_attention_2": + if xq.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.wq.weight.dtype + attn_output = _flash_attention_forward( + xq, + xk, + xv, + attention_mask=attn_mask, + query_length=inputs_q.shape[1], + is_causal=False, + dropout=dropout_p, + softmax_scale=xq.shape[-1] ** -0.5, + use_top_left_mask=flash_attn_supports_top_left_mask(), + target_dtype=target_dtype, + implementation=self.attn_implementation, + ) + else: + raise ValueError(f"Attention implementation {self.attn_implementation} not supported") + + attn_output = attn_output.to(og_dtype) + attn_output = self._merge_heads(attn_output) + attn_output = self.wo(attn_output) + attn_output = self.residual_dropout(attn_output) + + return attn_output + + +class MolmoAct2VisionBlock(nn.Module): + def __init__(self, config: MolmoAct2VitConfig, device: str | torch.device = None): + super().__init__() + self.attention = ViTMultiHeadDotProductAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + float32_attention=config.float32_attention, + attention_dropout=config.attention_dropout, + residual_dropout=config.residual_dropout, + device=device, + attn_implementation=config._attn_implementation, + ) + self.feed_forward = ViTMLP( + config.hidden_size, + config.intermediate_size, + config.hidden_act, + device=device, 
class MolmoAct2VisionBlockCollection(nn.Module):
    """Sequential stack of ``MolmoAct2VisionBlock`` layers that exposes every layer's output."""

    def __init__(self, config: MolmoAct2VitConfig, device: str | torch.device = None):
        super().__init__()
        self.config = config
        # One block per configured hidden layer, all created on the same device.
        blocks = []
        for _ in range(config.num_hidden_layers):
            blocks.append(MolmoAct2VisionBlock(config, device))
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Feed ``x`` through the blocks in order and return the output of each block."""
        per_layer_outputs: list[torch.Tensor] = []
        current = x
        for block in self.resblocks:
            current = block(current)
            per_layer_outputs.append(current)
        return per_layer_outputs
class MolmoAct2VisionBackbone(nn.Module):
    """ViT encoder plus attention pooling and an MLP projector for image features.

    Features from the ViT layers listed in ``adapter_config.vit_layers`` are concatenated
    along the feature axis, gathered per output token via ``pooled_patches_idx``, pooled
    with ``image_pooling_2d`` and projected by ``image_projector``.
    """

    def __init__(self, vit_config: MolmoAct2VitConfig, adapter_config: MolmoAct2AdapterConfig):
        super().__init__()
        self.vit_config = vit_config
        self.adapter_config = adapter_config

        # Resolve negative layer indices relative to the full ViT depth.
        self.vit_layers = []
        for layer in adapter_config.vit_layers:
            if layer >= 0:
                self.vit_layers.append(layer)
            else:
                self.vit_layers.append(layer + vit_config.num_hidden_layers)

        # Only build as many ViT layers as the deepest requested feature needs.
        last_layer_needed = max(self.vit_layers) + 1
        if last_layer_needed < vit_config.num_hidden_layers:
            new_vit_config = deepcopy(vit_config)
            new_vit_config.num_hidden_layers = last_layer_needed
            self.image_vit = MolmoAct2VisionTransformer(new_vit_config)
        else:
            self.image_vit = MolmoAct2VisionTransformer(vit_config)

        self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens

        pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)
        self.image_pooling_2d = ViTMultiHeadDotProductAttention(
            hidden_size=adapter_config.hidden_size,
            num_heads=adapter_config.num_attention_heads,
            num_key_value_heads=adapter_config.num_key_value_heads,
            head_dim=adapter_config.head_dim,
            input_dim=pool_dim,
            float32_attention=adapter_config.float32_attention,
            attention_dropout=adapter_config.attention_dropout,
            residual_dropout=adapter_config.residual_dropout,
            attn_implementation=adapter_config._attn_implementation,
        )
        self.image_projector = ImageProjectorMLP(
            adapter_config.hidden_size,
            adapter_config.intermediate_size,
            adapter_config.text_hidden_size,
            adapter_config.hidden_act,
        )
        self.image_feature_dropout = nn.Dropout(adapter_config.image_feature_dropout)
        self.gradient_checkpointing = False

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Run the ViT and concatenate the selected layers' features.

        Args:
            images: ``(batch_size, num_crops, num_patch, n_pixels)`` patchified pixels.

        Returns:
            ``(batch_size, num_crops, num_patch, feature_dim)`` where ``feature_dim`` is the
            concatenation of the selected layers' hidden sizes.

        Raises:
            RuntimeError: If a requested layer index was not produced by the ViT.
        """
        batch_size, num_crops, num_patches, patch_dim = images.shape
        images = images.view(batch_size * num_crops, num_patches, patch_dim)

        x = self.image_vit.patch_embedding(images)
        x = self.image_vit.add_pos_emb(x, self.image_vit.config.image_num_patch)

        needed_layers = {int(layer) for layer in self.vit_layers}
        selected_features: dict[int, torch.Tensor] = {}
        # Activation checkpointing only pays off (and is only valid) while training with grads.
        use_checkpoint = bool(self.gradient_checkpointing and self.training and torch.is_grad_enabled())
        for layer_idx, block in enumerate(self.image_vit.transformer.resblocks):
            if use_checkpoint:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
            if layer_idx in needed_layers:
                selected_features[layer_idx] = x

        missing = needed_layers - set(selected_features)
        if missing:
            raise RuntimeError(
                f"MolmoAct2 vision backbone did not produce requested layers: {sorted(missing)}."
            )

        image_features = torch.cat([selected_features[int(layer)] for layer in self.vit_layers], dim=-1)

        if self.num_prefix_tokens > 0:
            # Drop exactly the configured number of prefix tokens (was hard-coded to one).
            image_features = image_features[:, self.num_prefix_tokens :]
        image_features = image_features.view(batch_size, num_crops, num_patches, -1)
        return image_features

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the ViT patch-embedding weight."""
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the ViT patch-embedding weight."""
        return self.image_vit.patch_embedding.weight.device

    def forward(
        self,
        images: torch.Tensor,
        pooled_patches_idx: torch.Tensor,
    ) -> torch.Tensor:
        """Encode, pool and project image patches into per-token features.

        Args:
            images: ``(batch_size, num_crops, num_patch, n_pixels)``; ``uint8`` in [0, 255]
                or floating point in [-1, 1].
            pooled_patches_idx: ``(batch_size, num_tokens, pool_size)`` patch indices per
                output token; negative entries mark padding.

        Returns:
            A 2-D tensor ``(num_valid_tokens, projector_output_dim)`` holding only the
            tokens that have at least one valid patch index.
        """
        # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device)
        if images.dtype == torch.uint8:
            images = images.to(dtype=torch.float32) / 255.0
            images = images * 2.0 - 1.0
        elif torch.is_floating_point(images):
            # Native MolmoAct2 eval keeps resized SigLIP pixels as uint8 and normalizes
            # on device. Canonicalize HF processor floats to that exact grid.
            images = torch.round(((images.to(dtype=torch.float32) + 1.0) * 0.5) * 255.0)
            images = torch.clamp(images, 0.0, 255.0) / 255.0
            images = images * 2.0 - 1.0
        images = images.to(dtype=self.dtype)
        image_features = self.encode_image(images)

        image_features = self.image_feature_dropout(image_features)
        dim = image_features.shape[-1]
        valid = pooled_patches_idx >= 0
        valid_token = torch.any(valid, -1)

        # Use `pooled_patches_idx` to arrange the features for image pooling
        batch_idx = torch.arange(
            pooled_patches_idx.shape[0],
            dtype=torch.long,
            device=pooled_patches_idx.device,
        )
        batch_idx = torch.tile(
            batch_idx.view(batch_size, 1, 1),
            [1, pooled_patches_idx.shape[1], pooled_patches_idx.shape[2]],
        )

        # Now [batch, num_high_res_features, pool_dim, dim]
        to_pool = image_features.reshape(batch_size, -1, dim)[batch_idx, torch.clip(pooled_patches_idx, 0)]
        # Zero out features gathered through padding indices (they were clipped to 0 above).
        to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
        to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim])
        if self.adapter_config.pooling_attention_mask:
            attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
            # Masked mean as the pooling query; guard against tokens with no valid patch.
            denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
            denom = torch.where(denom == 0, 1, denom)
            query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(to_pool.dtype)
        else:
            attn_mask = None
            query = to_pool.mean(-2, keepdim=True)
        pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
        pooled_features = pooled_features.reshape([batch_size, -1, pooled_features.shape[-1]])

        # MLP layer to map the feature.
        pooled_features = self.image_projector(pooled_features)
        return pooled_features.view(-1, pooled_features.shape[-1])[valid_token.flatten()]
class MolmoAct2RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily built sin/cos lookup cache.

    ``inv_freq`` is a persistent buffer; ``_pos_sin_cache`` / ``_pos_cos_cache`` are
    non-persistent buffers of shape ``(1, 1, seq_len, dim)`` that are (re)built on demand
    and indexed by position in ``forward``.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(
        self,
        config: MolmoAct2TextConfig,
        device: str | torch.device = None,
        rope_type: str | None = None,
    ):
        """Resolve the RoPE type, compute ``inv_freq`` and register empty caches.

        An explicit ``rope_type`` wins; otherwise it is read from ``config.rope_scaling``
        (when that is a dict), falling back to ``"default"``.
        """
        super().__init__()
        if rope_type is not None:
            self.rope_type = rope_type
        elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            # BC: "rope_type" was originally "type"
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        if self.rope_type == "default":
            self.rope_init_fn = self._default_rope_init
        else:
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.original_inv_freq = self.inv_freq
        # Empty placeholders; the real tables are built on first use or via prepare_rope_cache.
        self.register_buffer("_pos_sin_cache", torch.empty(0), persistent=False)
        self.register_buffer("_pos_cos_cache", torch.empty(0), persistent=False)

    @staticmethod
    def _default_rope_init(
        config: MolmoAct2TextConfig, device: str | torch.device = None, **_
    ) -> tuple[torch.Tensor, float]:
        """Return ``1 / rope_theta ** (2i / head_dim)`` for even ``i`` and a scaling of 1.0."""
        inv_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, config.head_dim, 2, dtype=torch.float32, device=device) / config.head_dim)
        )
        return inv_freq, 1.0

    def _target_cache_seq_len(self, x: torch.Tensor, position_ids: torch.Tensor | None) -> int:
        """Pick the cache length: the configured maximum, else derived from the inputs."""
        # NOTE(review): when max_position_embeddings is set it is returned regardless of
        # position_ids, so positions >= that value would index past the cache — confirm
        # callers never exceed it.
        if self.config.max_position_embeddings:
            return int(self.config.max_position_embeddings)
        if position_ids is not None:
            return int(position_ids.max().item()) + 1
        return int(x.shape[-2])

    def _rope_cache_ready(self, device: torch.device, seq_len: int) -> bool:
        """True when both caches are non-empty, on ``device`` and cover ``seq_len`` positions."""
        return (
            self._pos_sin_cache.numel() > 0
            and self._pos_sin_cache.device == device
            and self._pos_cos_cache.device == device
            and self._pos_sin_cache.shape[-2] >= seq_len
            and self._pos_cos_cache.shape[-2] >= seq_len
        )

    def _refresh_inv_freq_if_needed(self, device: torch.device) -> None:
        """Recompute ``inv_freq`` on ``device`` and clear the caches when it looks unusable.

        A refresh is triggered when the buffer is missing, the cache is empty, the buffer
        is on the meta device or a different device, has the wrong length, or holds
        non-finite / non-positive values or a first element not close to 1.0.
        """
        device = torch.device(device)
        expected = int(self.config.head_dim) // 2
        needs_refresh = (
            self.inv_freq is None
            or self._pos_sin_cache.numel() == 0
            or self.inv_freq.device.type == "meta"
            or self.inv_freq.device != device
            or self.inv_freq.numel() != expected
        )
        if not needs_refresh:
            # Despite the name, this is only detached, not moved to CPU; the value checks
            # below synchronise through `.item()`.
            inv_freq_cpu = self.inv_freq.detach()
            needs_refresh = (
                not bool(torch.isfinite(inv_freq_cpu).all().item())
                or bool((inv_freq_cpu <= 0).any().item())
                or not bool(torch.isclose(inv_freq_cpu[0].cpu(), torch.tensor(1.0)).item())
            )
        if needs_refresh:
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
            self.register_buffer("inv_freq", inv_freq, persistent=True)
            self.original_inv_freq = self.inv_freq
            # Invalidate the tables so they are rebuilt from the fresh frequencies.
            self._pos_sin_cache = torch.empty(0, device=device)
            self._pos_cos_cache = torch.empty(0, device=device)

    def _build_rope_cache(self, device: torch.device, seq_len: int) -> None:
        """Build float32 sin/cos tables for positions ``0..seq_len-1`` with autocast disabled."""
        # Autocast is keyed on "cpu" when the device type is "mps".
        device_type = device.type if device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            seq = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = torch.einsum("i,j->ij", seq, self.inv_freq.to(device=device, dtype=torch.float))
            # Duplicate the frequencies so the table spans the full rotary dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            self._pos_sin_cache = emb.sin()[None, None, :, :] * self.attention_scaling
            self._pos_cos_cache = emb.cos()[None, None, :, :] * self.attention_scaling

    @torch.no_grad()
    def prepare_rope_cache(
        self,
        *,
        device: str | torch.device,
        max_seq_len: int | None = None,
    ) -> None:
        """Eagerly build the cache on ``device``; a no-op for non-default RoPE types.

        Raises:
            ValueError: If neither ``max_seq_len`` nor the config yields a positive length.
        """
        if self.rope_type != "default":
            return
        device = torch.device(device)
        seq_len = int(max_seq_len or self.config.max_position_embeddings or 0)
        if seq_len <= 0:
            raise ValueError("RoPE cache preparation requires a positive max sequence length.")
        if self._rope_cache_ready(device, seq_len):
            return
        self._refresh_inv_freq_if_needed(device)
        self._build_rope_cache(device, seq_len)

    def _select_rope_cache(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor | None,
        seq_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` in ``x``'s dtype, gathered by ``position_ids`` when given.

        Without ``position_ids`` the first ``x.shape[-2]`` cache rows are returned.
        """
        pos_sin = self._pos_sin_cache[:, :, :seq_len, :]
        pos_cos = self._pos_cos_cache[:, :, :seq_len, :]
        if position_ids is None:
            sin = pos_sin[0, 0, : x.shape[-2], :]
            cos = pos_cos[0, 0, : x.shape[-2], :]
        else:
            sin = pos_sin[0, 0][position_ids].view(position_ids.shape + (pos_sin.shape[-1],))
            cos = pos_cos[0, 0][position_ids].view(position_ids.shape + (pos_cos.shape[-1],))
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` for ``position_ids``, rebuilding the cache if it is not ready."""
        seq_len = self._target_cache_seq_len(x, position_ids)
        if not self._rope_cache_ready(x.device, seq_len):
            self._refresh_inv_freq_if_needed(x.device)
            self._build_rope_cache(x.device, seq_len)
        return self._select_rope_cache(x, position_ids, seq_len)
class MolmoAct2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Uses a single fused q/k/v projection (``att_proj``), optional RMS q/k normalisation,
    rotary position embeddings and grouped key/value heads.
    """

    def __init__(self, config: MolmoAct2TextConfig, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.head_dim = config.head_dim
        self.scaling = self.head_dim**-0.5
        self.is_causal = True

        # Output widths of the fused projection, in (query, key, value) order.
        self.fused_dims = (
            config.num_attention_heads * config.head_dim,
            config.head_dim * config.num_key_value_heads,
            config.head_dim * config.num_key_value_heads,
        )
        self.att_proj = nn.Linear(
            config.hidden_size,
            sum(self.fused_dims),
            bias=config.qkv_bias,
        )

        # Layer norms.
        self.k_norm: MolmoAct2RMSNorm | None = None
        self.q_norm: MolmoAct2RMSNorm | None = None
        self.qk_norm_type: str | None = None
        if config.use_qk_norm:
            # "qwen3" normalises per head (size head_dim); otherwise over all heads at once.
            k_norm_size = (
                config.head_dim
                if config.qk_norm_type == "qwen3"
                else config.num_key_value_heads * config.head_dim
            )
            self.k_norm = MolmoAct2RMSNorm(k_norm_size, eps=config.layer_norm_eps)
            q_norm_size = (
                config.head_dim
                if config.qk_norm_type == "qwen3"
                else config.num_attention_heads * config.head_dim
            )
            self.q_norm = MolmoAct2RMSNorm(q_norm_size, eps=config.layer_norm_eps)
            self.qk_norm_type = config.qk_norm_type

        self.attention_dropout = config.attention_dropout

        self.attn_out = nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            bias=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor | None, ...]:
        """Compute self-attention for one decoder layer.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` layer input.
            position_embeddings: ``(cos, sin)`` rotary tables.
            attention_mask: Optional mask; ``None`` means plain causal attention.
            past_key_values: Optional cache whose ``update`` returns the full key/value states.
            cache_position: Forwarded to the cache update.
            **kwargs: ``collect_layer_kv_states`` is consumed here; the rest is forwarded
                to the attention backend on the non-SDPA path.

        Returns:
            ``(attn_output, attn_weights)``, extended by ``(key_states, value_states)``
            (post-RoPE, post-cache, before head repetition) when
            ``collect_layer_kv_states`` is truthy. ``attn_weights`` is ``None`` on the
            SDPA path.
        """
        collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False))
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv = self.att_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1)
        value_states = value_states.view(hidden_shape)

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type != "qwen3":
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)
        if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "qwen3":
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        collected_key_states = key_states
        collected_value_states = value_states

        dropout_p = 0.0 if not self.training else self.attention_dropout
        if self.config._attn_implementation == "sdpa" and (
            attention_mask is None or torch.is_tensor(attention_mask)
        ):
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
            # SDPA's `is_causal` mask is aligned to the top-left corner, so with a single
            # query token and cached keys it would expose only the first key. A single
            # query may attend to every key, so causal masking is only requested when
            # there is more than one query position.
            use_causal = attention_mask is None and query_states.shape[-2] > 1
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=use_causal,
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_weights = None
        else:
            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=dropout_p,
                scaling=self.scaling,
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.attn_out(attn_output)
        if collect_layer_kv_states:
            return attn_output, attn_weights, collected_key_states, collected_value_states
        return attn_output, attn_weights
class MolmoAct2PostNormDecoderLayer(MolmoAct2DecoderLayer):
    """Decoder layer variant that normalises each sub-layer's output instead of its input."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        """Apply attention then MLP, each followed by its norm, dropout and a residual add.

        The returned tuple starts with the hidden states; attention weights follow when
        ``output_attentions`` is truthy, and the layer's key/value states follow when
        ``collect_layer_kv_states`` is passed as a truthy keyword.
        """
        want_kv = bool(kwargs.pop("collect_layer_kv_states", False))

        # Attention sub-layer: norm is applied to the attention output (post-norm).
        attn_result = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            collect_layer_kv_states=want_kv,
            **kwargs,
        )
        attn_hidden, attn_probs = attn_result[0], attn_result[1]
        after_attn = hidden_states + self.dropout(self.attn_norm(attn_hidden))

        # Feed-forward sub-layer, again normalised after the transformation.
        mlp_hidden = self.ff_norm(self.mlp(after_attn))
        layer_output = after_attn + self.dropout(mlp_hidden)

        result = (layer_output,)
        if output_attentions:
            result = result + (attn_probs,)
        if want_kv:
            result = result + (attn_result[2], attn_result[3])
        return result
(nn.Linear,)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, MolmoAct2Embedding): + module.embedding.data.normal_(mean=0.0, std=std) + module.new_embedding.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MolmoAct2RMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() + + +class MolmoAct2TextModel(MolmoAct2PreTrainedModel): + config: MolmoAct2TextConfig + _no_split_modules = ["MolmoAct2DecoderLayer", "MolmoAct2PostNormDecoderLayer"] + + def __init__(self, config: MolmoAct2TextConfig): + super().__init__(config) + if config.additional_vocab_size is not None: + self.wte = MolmoAct2Embedding( + config.vocab_size, + config.additional_vocab_size, + config.hidden_size, + ) + else: + self.wte = nn.Embedding(config.vocab_size, config.hidden_size) + self.emb_drop = nn.Dropout(config.embedding_dropout) + decoder_layer = MolmoAct2PostNormDecoderLayer if config.norm_after else MolmoAct2DecoderLayer + self.blocks = nn.ModuleList( + [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.ln_f = MolmoAct2RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + if config.rope_scaling_layers is not None: + self.rotary_embs = nn.ModuleDict( + { + "default": MolmoAct2RotaryEmbedding(config, rope_type="default"), + "scaling": MolmoAct2RotaryEmbedding(config), + } + ) + else: + self.rotary_emb = MolmoAct2RotaryEmbedding(config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def prepare_rope_cache( + self, + *, + device: str | torch.device, + max_seq_len: int | None = None, + ) -> None: 
+ if self.config.rope_scaling_layers is not None: + for rotary_emb in self.rotary_embs.values(): + rotary_emb.prepare_rope_cache(device=device, max_seq_len=max_seq_len) + return + self.rotary_emb.prepare_rope_cache(device=device, max_seq_len=max_seq_len) + + def get_input_embeddings(self) -> torch.nn.Module: + return self.wte + + def set_input_embeddings(self, value: torch.nn.Module) -> None: + self.wte = value + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) + if collect_layer_kv_states and past_key_values is not None: + raise ValueError("collect_layer_kv_states cannot be used with past_key_values.") + if collect_layer_kv_states: + use_cache = False + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + inputs_embeds = self.wte(input_ids) + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if torch.is_tensor(attention_mask) and attention_mask.ndim == 4: + causal_mask_mapping = attention_mask + elif not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + + # Create the mask + causal_mask_mapping = create_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + if self.config.rope_scaling_layers is not None: + position_embeddings_mapping = { + "default": self.rotary_embs["default"](hidden_states, position_ids), + "scaling": self.rotary_embs["scaling"](hidden_states, position_ids), + } + else: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + collected_kv_states = [] if collect_layer_kv_states else None + + for layer_idx, decoder_block in enumerate(self.blocks[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states += 
(hidden_states,) + + if self.config.rope_scaling_layers is not None: + position_embeddings_i = ( + position_embeddings_mapping["scaling"] + if layer_idx in self.config.rope_scaling_layers + else position_embeddings_mapping["default"] + ) + else: + position_embeddings_i = position_embeddings + + layer_outputs = decoder_block( + hidden_states, + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings_i, + collect_layer_kv_states=collect_layer_kv_states, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + output_idx = 1 + if output_attentions: + all_self_attns += (layer_outputs[output_idx],) + output_idx += 1 + if collect_layer_kv_states: + collected_kv_states.append((layer_outputs[output_idx], layer_outputs[output_idx + 1])) + + hidden_states = self.ln_f(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=tuple(collected_kv_states) if collect_layer_kv_states else past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Adapted from transformers.models.gemma3.modeling_gemma3 +def token_type_ids_mask_function( + token_type_ids: torch.Tensor | None = None, +) -> Callable | None: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. 
+ """ + # Do not return an additional mask in this case + if token_type_ids is None: + return None + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + # If it's 1 for both query and key/value, we are in an image block + # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length + # Since vmap doesn't support `if statement` we workaround it with `torch.where` + safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) + token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx] + token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0) + + is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1) + + # This is bidirectional attention whenever we are dealing with image tokens + return is_image_block & is_image_block + + return inner_mask + + +class MolmoAct2Model(MolmoAct2PreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: MolmoAct2Config + + def __init__(self, config: MolmoAct2Config): + super().__init__(config) + self.transformer: MolmoAct2TextModel = MolmoAct2TextModel(config.text_config) + self.vision_backbone: MolmoAct2VisionBackbone | None = None + if config.vit_config is not None and config.adapter_config is not None: + self.vision_backbone = MolmoAct2VisionBackbone(config.vit_config, config.adapter_config) + llm_kv_dim = config.text_config.num_key_value_heads * config.text_config.head_dim + if config.add_action_expert: + self.action_expert = ActionExpert( + config.action_expert_config, + llm_dim=config.hidden_size, + llm_kv_dim=llm_kv_dim, + llm_num_layers=config.num_hidden_layers, + ) + else: + self.action_expert = None + if config.add_action_expert and config.action_expert_depth_gate: + if config.action_expert_depth_gate_per_layer: + self.action_expert_depth_gate = 
nn.ModuleList( + nn.Linear(llm_kv_dim, 1) for _ in range(config.action_expert_config.num_layers) + ) + else: + self.action_expert_depth_gate = nn.Linear(llm_kv_dim, 1) + self.reset_action_expert_depth_gate_parameters() + else: + self.action_expert_depth_gate = None + self._depth_gate_token_ids = self._resolve_depth_gate_token_ids() + self.action_cuda_graph_manager: ActionCudaGraphManager | None = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> torch.nn.Module: + return self.transformer.wte + + def set_input_embeddings(self, value: torch.nn.Module) -> None: + self.transformer.wte = value + + def set_decoder(self, decoder): + self.transformer = decoder + + def get_decoder(self): + return self.transformer + + @property + def device(self) -> torch.device: + return self.transformer.ln_f.weight.device + + def reset_action_expert_depth_gate_parameters(self) -> None: + if self.action_expert_depth_gate is None: + return + gates = ( + self.action_expert_depth_gate + if isinstance(self.action_expert_depth_gate, nn.ModuleList) + else [self.action_expert_depth_gate] + ) + for gate in gates: + nn.init.zeros_(gate.weight) + nn.init.constant_(gate.bias, float(self.config.action_expert_depth_gate_init_bias)) + + def _resolve_depth_gate_token_ids(self) -> tuple[int, ...]: + if not self.config.action_expert_depth_gate: + return () + token_ids = [] + for token_id in ( + self.config.depth_output_token_id, + self.config.depth_start_token_id, + self.config.depth_end_token_id, + ): + if token_id is not None: + token_ids.append(int(token_id)) + if self.config.depth_token_start_id is not None and int(self.config.num_depth_tokens or 0) > 0: + start = int(self.config.depth_token_start_id) + token_ids.extend(range(start, start + int(self.config.num_depth_tokens))) + return tuple(dict.fromkeys(token_ids)) + + def _require_action_expert(self) -> ActionExpert: + if self.action_expert is None: + raise RuntimeError("This MolmoAct2 
checkpoint does not include an action expert.") + return self.action_expert + + def _cache_to_sequence(self, cache: torch.Tensor) -> torch.Tensor: + if cache.dim() != 4: + raise ValueError(f"Expected KV cache tensor with 4 dims, got shape {tuple(cache.shape)}") + head_candidates = { + self.config.text_config.num_key_value_heads, + self.config.text_config.num_attention_heads, + } + if cache.shape[1] in head_candidates: + bsz, n_heads, seq_len, head_dim = cache.shape + return cache.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * head_dim) + if cache.shape[2] in head_candidates: + bsz, seq_len, n_heads, head_dim = cache.shape + return cache.reshape(bsz, seq_len, n_heads * head_dim) + if cache.shape[1] <= cache.shape[2]: + bsz, n_heads, seq_len, head_dim = cache.shape + return cache.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * head_dim) + bsz, seq_len, n_heads, head_dim = cache.shape + return cache.reshape(bsz, seq_len, n_heads * head_dim) + + def _extract_kv_states(self, past_key_values: Cache) -> Sequence[tuple[torch.Tensor, torch.Tensor]]: + if past_key_values is None: + raise RuntimeError("Action generation requires past_key_values from the VLM forward pass.") + seq_len = _cache_seq_len_int(past_key_values) + kv_states = [] + for key, value in _iter_cache_key_values(past_key_values): + if key is None or value is None: + continue + if key.shape[-2] > seq_len: + key = key[..., :seq_len, :] + value = value[..., :seq_len, :] + kv_states.append((self._cache_to_sequence(key), self._cache_to_sequence(value))) + if len(kv_states) != self.config.action_expert_config.num_layers: + raise RuntimeError( + f"Expected {self.config.action_expert_config.num_layers} KV layers, got {len(kv_states)}." 
+ ) + return kv_states + + @staticmethod + def _mask_discrete_output_span( + row_ids: torch.Tensor, + row_mask: torch.Tensor, + start_id: int | None, + end_id: int | None, + ) -> None: + if start_id is None or end_id is None: + return + start_positions = (row_ids == start_id).nonzero(as_tuple=False).flatten().tolist() + if not start_positions: + return + end_positions = (row_ids == end_id).nonzero(as_tuple=False).flatten().tolist() + end_ptr = 0 + for start_pos in start_positions: + while end_ptr < len(end_positions) and end_positions[end_ptr] < start_pos: + end_ptr += 1 + if end_ptr >= len(end_positions): + row_mask[start_pos:] = False + break + end_pos = end_positions[end_ptr] + row_mask[start_pos : end_pos + 1] = False + end_ptr += 1 + + def _get_encoder_attention_mask( + self, + input_ids: torch.Tensor | None, + attention_mask: torch.Tensor | None, + ) -> torch.Tensor | None: + if attention_mask is not None: + mask = attention_mask.to(dtype=torch.bool).clone() + elif input_ids is not None: + mask = input_ids != -1 + else: + return None + if self.config.action_mode != "both" or input_ids is None: + return mask + eos_id = getattr(self.config, "eos_token_id", None) + if eos_id is not None: + mask &= input_ids != int(eos_id) + for batch_idx in range(input_ids.shape[0]): + self._mask_discrete_output_span( + input_ids[batch_idx], + mask[batch_idx], + self.config.action_start_token_id, + self.config.action_end_token_id, + ) + return mask + + def _get_depth_token_mask( + self, + input_ids: torch.Tensor | None, + encoder_attention_mask: torch.Tensor | None, + ) -> torch.Tensor | None: + if not self.config.action_expert_depth_gate or input_ids is None or not self._depth_gate_token_ids: + return None + depth_token_ids = torch.as_tensor( + self._depth_gate_token_ids, + device=input_ids.device, + dtype=input_ids.dtype, + ) + depth_mask = (input_ids.unsqueeze(-1) == depth_token_ids).any(dim=-1) + if encoder_attention_mask is not None: + depth_mask = depth_mask & 
encoder_attention_mask.to(device=input_ids.device, dtype=torch.bool) + return depth_mask + + @staticmethod + def _depth_gate_from_source( + gate_head: nn.Linear, + *, + source: torch.Tensor, + depth_mask: torch.Tensor, + encoder_attention_mask: torch.Tensor | None, + ) -> torch.Tensor: + if source.ndim == 4: + source = source.reshape(source.shape[0], source.shape[1], -1) + if source.ndim != 3: + raise ValueError(f"Depth gate expected a 3D sequence tensor, got {tuple(source.shape)}.") + if encoder_attention_mask is not None: + valid_mask = encoder_attention_mask.to(device=source.device, dtype=torch.bool) + else: + valid_mask = torch.ones(depth_mask.shape, device=source.device, dtype=torch.bool) + depth_mask = depth_mask.to(device=source.device, dtype=torch.bool) + pool_mask = valid_mask & ~depth_mask + has_pool = pool_mask.any(dim=-1, keepdim=True) + pool_mask = torch.where(has_pool, pool_mask, valid_mask) + weights = pool_mask.to(dtype=source.dtype).unsqueeze(-1) + pooled = (source * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0) + gate_logits = gate_head(pooled.to(dtype=gate_head.weight.dtype)) + return torch.sigmoid(gate_logits).to(dtype=source.dtype) + + def _depth_gate_from_condition( + self, + *, + input_ids: torch.Tensor | None, + encoder_attention_mask: torch.Tensor | None, + layer_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] | None, + ) -> tuple[torch.Tensor | Sequence[torch.Tensor] | None, torch.Tensor | None]: + gate_head = self.action_expert_depth_gate + if gate_head is None: + return None, None + depth_mask = self._get_depth_token_mask(input_ids, encoder_attention_mask) + if depth_mask is None or layer_kv_states is None: + return None, depth_mask + sources = [value for _, value in layer_kv_states] + if isinstance(gate_head, nn.ModuleList): + if len(gate_head) != len(sources): + raise ValueError( + f"Depth gate layer count mismatch: gates={len(gate_head)}, sources={len(sources)}." 
+ ) + gates = [ + self._depth_gate_from_source( + gate, + source=source, + depth_mask=depth_mask, + encoder_attention_mask=encoder_attention_mask, + ) + for gate, source in zip(gate_head, sources) + ] + return gates, depth_mask + gate = self._depth_gate_from_source( + gate_head, + source=sources[-1], + depth_mask=depth_mask, + encoder_attention_mask=encoder_attention_mask, + ) + return gate, depth_mask + + @staticmethod + def _depth_gate_for_layer( + gate: torch.Tensor | Sequence[torch.Tensor], + layer_idx: int, + *, + num_layers: int, + ) -> torch.Tensor: + if isinstance(gate, torch.Tensor): + return gate + if len(gate) != num_layers: + raise ValueError(f"Depth gate layer count mismatch: gates={len(gate)}, layers={num_layers}.") + return gate[layer_idx] + + def _apply_depth_gate_to_layer_kv_states( + self, + layer_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] | None, + depth_mask: torch.Tensor | None, + gate: torch.Tensor | Sequence[torch.Tensor] | None, + ) -> Sequence[tuple[torch.Tensor, torch.Tensor]] | None: + if layer_kv_states is None or depth_mask is None or gate is None: + return layer_kv_states + gated_kv = [] + for layer_idx, (key, value) in enumerate(layer_kv_states): + layer_gate = self._depth_gate_for_layer(gate, layer_idx, num_layers=len(layer_kv_states)) + mask = depth_mask.to(device=key.device, dtype=torch.bool) + view_shape = [mask.shape[0], mask.shape[1]] + [1] * (key.ndim - 2) + scale = torch.ones(view_shape, device=key.device, dtype=key.dtype) + gate_view = layer_gate.to(device=key.device, dtype=key.dtype).view( + layer_gate.shape[0], + *([1] * (key.ndim - 1)), + ) + scale = torch.where(mask.view(view_shape), gate_view, scale) + gated_kv.append((key * scale, value * scale)) + return gated_kv + + @staticmethod + def _action_dim_valid_mask( + target: torch.Tensor, + action_dim_is_pad: torch.Tensor | None, + ) -> torch.Tensor | None: + if action_dim_is_pad is None: + return None + mask = ~action_dim_is_pad.to(device=target.device, 
dtype=torch.bool) + if mask.ndim == 1: + mask = mask.unsqueeze(0) + if mask.shape[-1] != target.shape[-1]: + raise ValueError( + f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}." + ) + if mask.shape[0] == 1 and target.shape[0] != 1: + mask = mask.expand(target.shape[0], -1) + if mask.shape[0] != target.shape[0]: + raise ValueError( + f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}." + ) + while mask.ndim < target.ndim: + mask = mask.unsqueeze(1) + return mask + + @classmethod + def _mask_action_dim_tensor( + cls, + tensor: torch.Tensor, + *, + action_dim_is_pad: torch.Tensor | None, + enabled: bool, + ) -> torch.Tensor: + if not enabled: + return tensor + valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad) + if valid_mask is None: + return tensor + return tensor.masked_fill(~valid_mask, 0) + + def _run_action_flow_loop(self, inputs: _ActionFlowInputs, steps: int) -> torch.Tensor: + action_expert = self._require_action_expert() + dt = 1.0 / steps + trajectory = inputs.trajectory + action_dim_is_pad = inputs.action_dim_is_pad + mask_enabled = self.config.mask_action_dim_padding + for idx in range(steps): + velocity = action_expert.forward_with_context( + trajectory, + inputs.modulations[idx].conditioning, + context=inputs.context, + modulation=inputs.modulations[idx], + ) + velocity = self._mask_action_dim_tensor( + velocity, + action_dim_is_pad=action_dim_is_pad, + enabled=mask_enabled, + ) + trajectory = trajectory + dt * velocity + trajectory = self._mask_action_dim_tensor( + trajectory, + action_dim_is_pad=action_dim_is_pad, + enabled=mask_enabled, + ) + return trajectory + + def _resolve_action_horizon(self, action_horizon: int | None = None) -> int: + max_action_horizon = int(self.config.max_action_horizon or 1) + resolved = max_action_horizon if action_horizon is None else int(action_horizon) + if resolved < 1: + raise ValueError(f"action_horizon must be >= 1, got 
{resolved}.") + if resolved > max_action_horizon: + raise ValueError( + f"Requested action_horizon={resolved} exceeds checkpoint max_action_horizon={max_action_horizon}." + ) + return resolved + + @torch.no_grad() + def generate_actions_from_inputs( + self, + *, + input_ids: torch.LongTensor, + pixel_values: torch.Tensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.LongTensor | None = None, + states: torch.Tensor | None = None, + action_dim_is_pad: torch.Tensor | None = None, + action_horizon: int | None = None, + num_steps: int | None = None, + generator: torch.Generator | None = None, + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + action_expert = self._require_action_expert() + if encoder_kv_states is None: + outputs = self( + input_ids=input_ids, + pixel_values=pixel_values, + image_token_pooling=image_token_pooling, + image_grids=image_grids, + image_num_crops=image_num_crops, + pixel_values_videos=pixel_values_videos, + video_token_pooling=video_token_pooling, + video_grids=video_grids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + use_cache=True, + ) + encoder_kv_states = self._extract_kv_states(outputs.past_key_values) + encoder_attention_mask = self._get_encoder_attention_mask(input_ids, attention_mask) + elif encoder_attention_mask is None: + encoder_attention_mask = self._get_encoder_attention_mask(input_ids, attention_mask) + + depth_gate, depth_mask = self._depth_gate_from_condition( + input_ids=input_ids, + encoder_attention_mask=encoder_attention_mask, + layer_kv_states=encoder_kv_states, + ) + 
def build_batched_images(
    self,
    input_ids: torch.LongTensor,
    pixel_values: torch.Tensor,
    image_token_pooling: torch.Tensor,
    image_grids: torch.Tensor,
    image_num_crops: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Regroup flat per-image inputs into per-example padded batches.

    Args:
        input_ids: ``[N, seq_len]`` token ids; images per example are counted
            from occurrences of ``config.image_end_token_id`` (either one or
            two end tokens per image across the whole batch).
        pixel_values: ``[n_crops, n_patches, pixels_per_patch]``, crops of all
            images concatenated in batch order.
        image_token_pooling: ``[total_pooled_patches, dim]`` patch indices
            local to each image; negative entries are left untouched.
        image_grids: one row per image; the pooled patch count of an image is
            ``prod(row[:2]) + prod(row[2:])``.
        image_num_crops: number of crops per image.

    Returns:
        ``(images, token_pooling)`` where ``images`` is
        ``[N, max_crops, n_patches, pixels_per_patch]`` and ``token_pooling``
        is ``[N, max_pooled_patches, dim]``, both padded with ``-1``. Pooling
        indices of the k-th image in an example are shifted by the total
        patch count of the images before it in that example.

    Raises:
        ValueError: if the end-token count is neither 1x nor 2x the number of
            images.
        AssertionError: if the flat inputs are inconsistent with each other.
    """
    # 1) Count the number of images in each example
    raw_counts = (input_ids == self.config.image_end_token_id).sum(1)  # [N]
    total_images = int(image_grids.size(0))
    total_end_tokens = int(raw_counts.sum().item())
    if total_images <= 0:
        counts = raw_counts.new_zeros(raw_counts.shape)
    elif total_end_tokens == total_images:
        counts = raw_counts
    elif total_end_tokens == 2 * total_images:
        counts = raw_counts // 2
    else:
        raise ValueError(
            "Could not infer image counts from image end tokens: "
            f"end_tokens={total_end_tokens}, image_grids={total_images}."
        )
    N = counts.size(0)
    device = input_ids.device

    # Total number of images in the batch
    num_images = total_images

    # Sanity check
    assert image_grids.size(0) == num_images, (
        f"Expected {num_images} image grids, but got {image_grids.size(0)}"
    )
    assert image_num_crops.size(0) == num_images, (
        f"Expected {num_images} image num crops, but got {image_num_crops.size(0)}"
    )

    # 1-1) Compute per-image pooled patch count from image grids
    with torch.no_grad():
        first_prod = image_grids[:, :2].prod(dim=1)  # [num_images]
        second_prod = image_grids[:, 2:].prod(dim=1)  # [num_images]
        num_pooled_patches_per_image = (first_prod + second_prod).to(
            image_num_crops.dtype
        )  # [num_images]

    # pixel_values: [n_crops, n_patches, pixels_per_patch]
    n_crops, n_patches, pixels_per_patch = pixel_values.shape

    # 2) Map each image index -> example index
    # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2]
    example_ids_for_image = torch.arange(N, device=device).repeat_interleave(counts)  # [num_images]
    assert example_ids_for_image.numel() == num_images

    # 2-1) Compute crops_per_example by summing per-image crop counts
    crops_per_example = torch.zeros(N, dtype=image_num_crops.dtype, device=image_num_crops.device)
    crops_per_example.index_add_(0, example_ids_for_image, image_num_crops)  # [N]

    # 2-2) Per-image number of patches = (crops per image) * n_patches
    patches_per_image = image_num_crops * n_patches  # [num_images]

    # 2-3) Compute per-example per-image patch offsets
    counts_list = counts.tolist()
    index_offset_per_example_list = []
    offset_img = 0
    for c in counts_list:
        if c == 0:
            # An example without images has no offsets. The unconditional
            # `[0] + cumsum[:-1]` form would yield a stray `[0]` here and
            # trip the length assertion in step 6 for mixed batches.
            index_offset_per_example_list.append([])
            continue
        per_img_patches = patches_per_image[offset_img : offset_img + c]  # [c]
        # Offsets: [0, img0_total_patches, img0+img1_total_patches, ...]
        index_offset = [0] + per_img_patches.cumsum(0).tolist()[:-1]
        index_offset_per_example_list.append(index_offset)
        offset_img += c

    # 2-4) Compute num_pooled_patches_per_example
    num_pooled_patches_per_example = torch.zeros(
        N,
        dtype=num_pooled_patches_per_image.dtype,
        device=num_pooled_patches_per_image.device,
    )
    num_pooled_patches_per_example.index_add_(0, example_ids_for_image, num_pooled_patches_per_image)

    # Sanity checks
    total_crops = int(crops_per_example.sum().item())
    assert total_crops == n_crops, f"Expected {total_crops} crops, but got {n_crops}"

    total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item())
    assert total_num_pooled_patches == image_token_pooling.size(0), (
        f"Expected {total_num_pooled_patches} pooled patches, but got {image_token_pooling.size(0)}"
    )

    # 3) Build images tensor filled with -1
    M = int(crops_per_example.max().item())
    images = torch.full(
        (N, M, n_patches, pixels_per_patch),
        fill_value=-1,
        dtype=pixel_values.dtype,
        device=pixel_values.device,
    )

    # 4) Fill images with per-example slices from pixel_values
    offset_crop = 0
    for i in range(N):
        num = int(crops_per_example[i].item())
        cur = pixel_values[offset_crop : offset_crop + num]  # [num, n_patches, pixels_per_patch]
        images[i, :num] = cur
        offset_crop += num

    # Sanity check
    assert offset_crop == n_crops

    # 5) Build new_token_pooling tensor filled with -1
    P = int(num_pooled_patches_per_example.max().item())
    _, dim = image_token_pooling.shape
    new_token_pooling = torch.full(
        (N, P, dim),
        fill_value=-1,
        dtype=image_token_pooling.dtype,
        device=image_token_pooling.device,
    )

    # 6) Fill token_pooling with per-example slices, adding per-image patch offsets
    patch_offset = 0
    img_offset = 0

    for i, c in enumerate(counts_list):
        num_patches = int(num_pooled_patches_per_example[i].item())

        # Subsequence of pooled tokens belonging to this example
        cur = image_token_pooling[patch_offset : patch_offset + num_patches].clone()  # [num_patches, dim]

        index_offset_per_example = index_offset_per_example_list[i]  # length = c
        per_img_pooled = num_pooled_patches_per_image[img_offset : img_offset + c]  # [c]

        assert len(index_offset_per_example) == per_img_pooled.numel()

        # Apply per-image offsets to the (ragged) subsequence
        offset = 0
        for j in range(c):
            index_offset = int(index_offset_per_example[j])
            n = int(per_img_pooled[j].item())
            cur_slice = cur[offset : offset + n]

            # Shift valid indices only; negative (padding) entries stay as-is
            cur[offset : offset + n] = torch.where(
                cur_slice >= 0,
                cur_slice + index_offset,
                cur_slice,
            )
            offset += n

        new_token_pooling[i, :num_patches] = cur

        patch_offset += num_patches
        img_offset += c

    # Final sanity checks
    assert patch_offset == total_num_pooled_patches
    assert img_offset == num_images

    return images, new_token_pooling
videos in the batch + num_videos = int(counts.sum().item()) + + # Sanity check + assert video_grids.size(0) == num_videos, ( + f"Expected {num_videos} videos, but got {video_grids.size(0)}" + ) + + video_num_frames = video_grids[:, 0] # [num_videos] + num_pooled_patches_per_video = video_grids.prod(dim=1) # [num_videos] + + # pixel_values_videos: [n_frames, n_patches, pixels_per_patch] + n_frames, n_patches, pixels_per_patch = pixel_values_videos.shape + + # 2) Map each video index -> example index + # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2] + example_ids_for_video = torch.arange(N, device=device).repeat_interleave(counts) # [num_videos] + assert example_ids_for_video.numel() == num_videos + + # 2-1) Compute frames_per_example by summing per-video frame counts + frames_per_example = torch.zeros( + N, + dtype=video_num_frames.dtype, + device=device, + ) + frames_per_example.index_add_(0, example_ids_for_video, video_num_frames) # [N] + + # 2-2) Compute num_pooled_patches_per_example + num_pooled_patches_per_example = torch.zeros( + N, + dtype=num_pooled_patches_per_video.dtype, + device=num_pooled_patches_per_video.device, + ) + num_pooled_patches_per_example.index_add_( + 0, + example_ids_for_video, + num_pooled_patches_per_video, + ) + + # Sanity checks + total_frames = int(frames_per_example.sum().item()) + assert total_frames == n_frames, f"Expected {total_frames} frames, but got {n_frames}" + + total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item()) + assert total_num_pooled_patches == video_token_pooling.size(0), ( + f"Expected {total_num_pooled_patches} pooled patches, but got {video_token_pooling.size(0)}" + ) + + # 3) Build videos tensor filled with -1 + M = int(frames_per_example.max().item()) + videos = torch.full( + (N, M, n_patches, pixels_per_patch), + fill_value=-1, + dtype=pixel_values_videos.dtype, + device=device, + ) + + # 4) Fill videos with per-examples slices from pixel_values_videos + offset_frame = 
0 + for i in range(N): + num = int(frames_per_example[i].item()) + cur = pixel_values_videos[offset_frame : offset_frame + num] # [num, n_patches, pixels_per_patch] + videos[i, :num] = cur + offset_frame += num + + # Sanity check + assert offset_frame == n_frames + + # 5) Build new token_pooling tensor filled with -1 + P = int(num_pooled_patches_per_example.max().item()) + _, dim = video_token_pooling.shape + new_token_pooling = torch.full( + (N, P, dim), + fill_value=-1, + dtype=video_token_pooling.dtype, + device=video_token_pooling.device, + ) + + # 6) Fill new token_pooling with per-examples slices from video_token_pooling + patch_offset = 0 + for i in range(N): + num_patches = int(num_pooled_patches_per_example[i].item()) + cur = video_token_pooling[patch_offset : patch_offset + num_patches] # [num_patches, dim] + new_token_pooling[i, :num_patches] = cur + patch_offset += num_patches + + # Final sanity checks + assert patch_offset == total_num_pooled_patches + + return videos, new_token_pooling + + def merge_visual_inputs( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if pixel_values is not None and pixel_values_videos is not None: + raise ValueError("pixel_values and pixel_values_videos are provided at the same time") + elif pixel_values is not None: + assert input_ids is not None + images, token_pooling = self.build_batched_images( + input_ids=input_ids, + pixel_values=pixel_values, + image_token_pooling=image_token_pooling, + image_grids=image_grids, + image_num_crops=image_num_crops, + ) + elif pixel_values_videos is not None: + assert input_ids is not None + images, 
token_pooling = self.build_batched_videos( + input_ids=input_ids, + pixel_values_videos=pixel_values_videos, + video_token_pooling=video_token_pooling, + video_grids=video_grids, + ) + else: + images, token_pooling = None, None + return images, token_pooling + + def build_input_embeddings( + self, + input_ids: torch.LongTensor, + images: torch.FloatTensor | None = None, # image inputs + token_pooling: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + x = self.transformer.wte(input_ids) + + image_features: torch.FloatTensor | None = None + if images is not None: + image_features = self.vision_backbone(images, token_pooling).to(x.device) + is_image_patch = input_ids.reshape(-1) == self.config.image_patch_id + if is_image_patch.sum() != len(image_features): + raise RuntimeError( + f"Expected {int(is_image_patch.sum())} image patch embeddings, got {len(image_features)}." 
def _build_native_attention_bias(
    self,
    *,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    token_type_ids: torch.Tensor | None,
    past_key_values: Cache | None,
) -> torch.Tensor:
    """Build an additive attention bias for the current forward step.

    Args:
        inputs_embeds: embeddings for the current step; the first two
            dimensions are read as ``(batch_size, seq_len)``.
        attention_mask: ``None``, a 2-D mask, or a 4-D tensor. A 4-D tensor
            is returned unchanged apart from being moved to the embeddings'
            device.
        token_type_ids: optional per-token flags; only used when no tokens
            are cached yet (``past_length == 0``).
        past_key_values: optional cache, inspected through
            ``_cache_seq_len_int`` and ``_cache_max_len_int``.

    Returns:
        A tensor in ``inputs_embeds.dtype`` holding ``0`` where attention is
        allowed and ``torch.finfo(dtype).min`` elsewhere.

    Raises:
        ValueError: if ``attention_mask`` is neither ``None``, 2-D nor 4-D.
    """
    # A 4-D mask is passed through as-is (only moved to the right device).
    if attention_mask is not None and attention_mask.ndim == 4:
        return attention_mask.to(device=inputs_embeds.device)
    batch_size, seq_len = inputs_embeds.shape[:2]
    # NOTE(review): presumably the number of tokens already in the cache — confirm in helper.
    past_length = _cache_seq_len_int(past_key_values)
    current_length = past_length + int(seq_len)
    # NOTE(review): presumably > 0 only for caches with a fixed capacity — confirm in helper.
    max_cache_len = _cache_max_len_int(past_key_values)
    # Key axis spans the cache capacity when one is reported, else just the tokens seen so far.
    attention_mask_len = max_cache_len if max_cache_len > 0 else current_length
    device = inputs_embeds.device

    if attention_mask is None:
        # No mask given: every key position below current_length is valid.
        positions = torch.arange(attention_mask_len, device=device)
        valid_mask = positions.unsqueeze(0) < current_length
        valid_mask = valid_mask.expand(batch_size, -1)
    elif attention_mask.ndim == 2:
        # Copy the caller's mask into a buffer of the target width; columns
        # not covered by the caller's mask stay False.
        valid_mask = torch.zeros((batch_size, attention_mask_len), device=device, dtype=torch.bool)
        source_mask = attention_mask.to(device=device, dtype=torch.bool)
        copy_len = min(int(source_mask.shape[-1]), attention_mask_len)
        if copy_len > 0:
            valid_mask[:, :copy_len] = source_mask[:, :copy_len]
        # Key slots at or beyond current_length are never attendable.
        if attention_mask_len > current_length:
            valid_mask[:, current_length:] = False
    else:
        raise ValueError(f"Unsupported attention_mask shape for MolmoAct2: {tuple(attention_mask.shape)}")

    # Broadcast key validity over the head and query axes.
    valid_mask = valid_mask[:, None, None, :]
    # Lower-triangular mask restricted to the query rows of the current step.
    causal_mask = torch.tril(
        torch.ones(attention_mask_len, attention_mask_len, device=device, dtype=torch.bool)
    )[None, None, past_length:current_length, :attention_mask_len]

    if token_type_ids is not None and past_length == 0:
        # With nothing cached, positions whose token_type_ids are both truthy
        # may attend to each other in either direction, on top of the causal mask.
        causal_mask = causal_mask.expand(batch_size, -1, -1, -1).clone()
        image_mask = token_type_ids.to(device=device, dtype=torch.bool)
        can_attend_back = image_mask[:, :, None] & image_mask[:, None, :]
        image_len = min(int(token_type_ids.shape[1]), attention_mask_len)
        causal_mask[:, :, :, :image_len] = (
            causal_mask[:, :, :, :image_len] | can_attend_back[:, None, :, :image_len]
        )

    # Convert the boolean "allowed" mask into an additive bias.
    allowed = valid_mask & causal_mask
    return torch.where(
        allowed,
        torch.zeros((), device=device, dtype=inputs_embeds.dtype),
        torch.full(
            (),
            torch.finfo(inputs_embeds.dtype).min,
            device=device,
            dtype=inputs_embeds.dtype,
        ),
    )
+ pixel_values=pixel_values, + image_token_pooling=image_token_pooling, + image_grids=image_grids, + image_num_crops=image_num_crops, + pixel_values_videos=pixel_values_videos, + video_token_pooling=video_token_pooling, + video_grids=video_grids, + ) + + if images is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both images and inputs_embeds at the same time.") + + if inputs_embeds is None: + inputs_embeds, image_features = self.build_input_embeddings( + input_ids, + images, + token_pooling, + ) + + if cache_position is None: + past_seen_tokens = _cache_seq_len_int(past_key_values) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if isinstance(attention_mask, dict): + causal_mask_mapping = attention_mask + else: + causal_mask_mapping = self._build_native_attention_bias( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + past_key_values=past_key_values, + ) + + outputs = self.transformer( + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + return MolmoAct2ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if images is not None else None, + ) + + +class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = [] # Weights are not tied + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: MolmoAct2Config + + def __init__(self, config: MolmoAct2Config): + super().__init__(config) + + 
def get_input_embeddings(self) -> torch.nn.Module:
    """Return the ``wte`` module of the wrapped model's transformer."""
    return self.model.transformer.wte

def set_input_embeddings(self, value: torch.nn.Module) -> None:
    """Replace the ``wte`` module of the wrapped model's transformer with ``value``."""
    self.model.transformer.wte = value

def set_decoder(self, decoder):
    """Delegate to ``self.model.set_decoder``."""
    self.model.set_decoder(decoder)

def get_decoder(self):
    """Delegate to ``self.model.get_decoder``."""
    return self.model.get_decoder()

# Make modules available through conditional class for BC
@property
def language_model(self) -> torch.nn.Module:
    """The wrapped model's ``transformer`` module."""
    return self.model.transformer

@property
def vision_backbone(self) -> torch.nn.Module:
    """The wrapped model's ``vision_backbone`` module."""
    return self.model.vision_backbone
+ ) from exc + with open(stats_path, encoding="utf-8") as f: + payload = json.load(f) + stats = _RobotStats(payload) + self._molmoact2_robot_stats = stats + return stats + + @staticmethod + def _move_inputs_to_device(inputs: Mapping[str, Any], device: torch.device) -> dict[str, Any]: + out = {} + for key, value in inputs.items(): + out[key] = value.to(device) if torch.is_tensor(value) else value + return out + + @staticmethod + def _drop_trivial_attention_mask(inputs: Mapping[str, Any]) -> dict[str, Any]: + out = dict(inputs) + attention_mask = out.get("attention_mask") + if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): + out.pop("attention_mask", None) + return out + + @staticmethod + def _count_images(images: Any) -> int: + if images is None: + return 0 + if isinstance(images, (list, tuple)): + return len(images) + arr = np.asarray(images) if not torch.is_tensor(images) else images + if getattr(arr, "ndim", 0) == 4: + return int(arr.shape[0]) + return 1 + + @staticmethod + def _build_action_dim_is_pad( + *, + action_dim: int, + max_action_dim: int, + batch_size: int, + device: torch.device, + ) -> torch.Tensor | None: + if int(action_dim) > int(max_action_dim): + raise ValueError( + f"Requested action_dim {int(action_dim)} exceeds checkpoint max_action_dim {int(max_action_dim)}." + ) + if int(action_dim) == int(max_action_dim): + return None + mask = torch.ones((int(batch_size), int(max_action_dim)), device=device, dtype=torch.bool) + mask[:, : int(action_dim)] = False + return mask + + @staticmethod + def _slice_action_dim(actions: torch.Tensor, action_dim: int) -> torch.Tensor: + if actions.shape[-1] < int(action_dim): + raise ValueError( + f"Requested action_dim {int(action_dim)} but chunk only has width {actions.shape[-1]}." 
+ ) + return actions[..., : int(action_dim)] + + @staticmethod + def _slice_action_chunk( + actions: torch.Tensor, n_obs_steps: int, n_action_steps: int | None + ) -> torch.Tensor: + if n_action_steps is None: + return actions + start = int(n_obs_steps) - 1 + end = start + int(n_action_steps) + if end > actions.shape[1]: + raise ValueError(f"Requested actions up to {end} but model produced horizon {actions.shape[1]}.") + return actions[:, start:end] + + def _depth_token_id_to_bin(self) -> dict[int, int]: + if self.config.depth_token_start_id is None or int(self.config.num_depth_tokens or 0) <= 0: + return {} + start = int(self.config.depth_token_start_id) + return {start + idx: idx for idx in range(int(self.config.num_depth_tokens))} + + def _action_token_id_to_bin(self) -> dict[int, int]: + if self.config.action_token_start_id is None or int(self.config.num_action_tokens or 0) <= 0: + return {} + start = int(self.config.action_token_start_id) + return {start + idx: idx for idx in range(int(self.config.num_action_tokens))} + + def _require_eos_token_id(self) -> int: + eos_token_id = getattr(self.config, "eos_token_id", None) + if eos_token_id is None and getattr(self, "generation_config", None) is not None: + eos_token_id = getattr(self.generation_config, "eos_token_id", None) + if isinstance(eos_token_id, (list, tuple)): + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None: + raise RuntimeError( + "Discrete action generation requires `eos_token_id` in the converted HF config." 
+ ) + return int(eos_token_id) + + def _decode_depth_bins_from_token_ids(self, token_ids: torch.Tensor) -> torch.Tensor: + if self.config.depth_start_token_id is None or self.config.depth_end_token_id is None: + raise RuntimeError("Depth generation requires / token IDs.") + token_id_to_bin = self._depth_token_id_to_bin() + if not token_id_to_bin: + raise RuntimeError("Depth generation requires indexed depth tokens in the converted config.") + depth_token_bins = _extract_discrete_token_bins( + _flatten_generated_token_ids(token_ids), + int(self.config.depth_start_token_id), + int(self.config.depth_end_token_id), + token_id_to_bin, + ) + if not depth_token_bins: + raise RuntimeError("Model generated no decodable depth tokens between /.") + return torch.as_tensor([depth_token_bins], device=self.device, dtype=torch.long) + + def _consume_generation_tokens( + self, + token_ids: torch.Tensor, + *, + past_key_values: Cache | None, + attention_mask: torch.Tensor | None, + ) -> tuple[MolmoAct2CausalLMOutputWithPast, torch.Tensor | None]: + if token_ids.ndim == 1: + next_input_ids = token_ids.unsqueeze(1) + elif token_ids.ndim == 2: + next_input_ids = token_ids + else: + raise ValueError(f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}.") + next_attention_mask = attention_mask + if next_attention_mask is not None: + past_length = _cache_seq_len_int(past_key_values) + required_len = int(past_length) + int(next_input_ids.shape[1]) + if int(next_attention_mask.shape[-1]) < required_len: + pad_len = required_len - int(next_attention_mask.shape[-1]) + next_attention_mask = torch.cat( + ( + next_attention_mask, + next_attention_mask.new_ones((next_input_ids.shape[0], pad_len)), + ), + dim=-1, + ) + past_length = _cache_seq_len_int(past_key_values) + output = self( + input_ids=next_input_ids, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + use_cache=True, + cache_position=( + torch.arange( + past_length, + past_length + 
def _make_depth_decode_attention_bias(
    self, inputs: Mapping[str, Any], past_key_values: Cache
) -> torch.Tensor:
    """Precompute a causal additive attention bias spanning the whole cache capacity.

    The capacity is read from ``past_key_values.layers[0].max_cache_len``.
    The result has shape ``(batch, 1, capacity, capacity)`` in the dtype of
    ``self.lm_head.weight``: ``0`` where a query position may attend to a key
    position, ``torch.finfo(dtype).min`` elsewhere. Key positions covered by
    ``inputs["attention_mask"]`` follow that mask; the remaining ones are
    treated as valid.

    Raises:
        RuntimeError: if the cache exposes no positive ``max_cache_len``.
    """
    cache_layers = getattr(past_key_values, "layers", None)
    capacity = int(getattr(cache_layers[0], "max_cache_len", 0)) if cache_layers else 0
    if capacity <= 0:
        raise RuntimeError("Depth decode fast path requires a cache with a fixed maximum length.")

    token_ids = inputs["input_ids"]
    n_rows = int(token_ids.shape[0])
    device = token_ids.device
    bias_dtype = self.lm_head.weight.dtype

    # Per-example key validity: start fully valid, then overlay the caller's mask.
    key_valid = torch.ones((n_rows, capacity), device=device, dtype=torch.bool)
    provided = inputs.get("attention_mask")
    if provided is not None:
        provided = provided.to(device=device, dtype=torch.bool)
        width = min(int(provided.shape[-1]), capacity)
        if width > 0:
            key_valid[:, :width] = provided[:, :width]

    # causal[q, k] is True when key k is at or before query q.
    index = torch.arange(capacity, device=device, dtype=torch.long)
    causal = index.unsqueeze(0) <= index.unsqueeze(1)
    allowed = (causal[None, :, :] & key_valid[:, None, :]).unsqueeze(1)

    zero = torch.zeros((), device=device, dtype=bias_dtype)
    blocked = torch.full((), torch.finfo(bias_dtype).min, device=device, dtype=bias_dtype)
    return torch.where(allowed, zero, blocked)
def _run_ar_decode_step(
    self,
    token_ids: torch.Tensor,
    *,
    past_key_values: Cache,
    attention_bias: torch.Tensor,
) -> tuple[torch.Tensor, Cache]:
    """Feed ``token_ids`` through the transformer for one decode step.

    Args:
        token_ids: rank-1 ids (one per batch row) or rank-2 ids.
        past_key_values: cache passed to (and returned by) the step.
        attention_bias: full additive bias; the rows and columns matching the
            current step are sliced out of it on the non-graph path.

    Returns:
        ``(last_hidden, past_key_values)``. On the non-graph path
        ``last_hidden`` is the hidden state of the final position with the
        sequence dimension kept; on the graph path it is whatever
        ``depth_decode_cuda_graph_manager.run`` returns.

    Raises:
        ValueError: if ``token_ids`` is not rank 1 or 2.
    """
    rank = token_ids.ndim
    if rank not in (1, 2):
        raise ValueError(f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}.")
    step_ids = token_ids.unsqueeze(1) if rank == 1 else token_ids

    cached_len = _cache_seq_len_int(past_key_values)
    stop = cached_len + int(step_ids.shape[1])

    # Hand the step to the CUDA-graph manager whenever it reports it can take it.
    graph_manager = self.depth_decode_cuda_graph_manager
    if graph_manager.can_use(
        step_ids,
        past_key_values=past_key_values,
        attention_bias=attention_bias,
    ):
        return graph_manager.run(
            step_ids,
            past_key_values=past_key_values,
            attention_bias=attention_bias,
            past_length=cached_len,
        )

    # Otherwise call the transformer directly on this step's slice.
    step_positions = torch.arange(cached_len, stop, device=step_ids.device, dtype=torch.long)
    step_bias = attention_bias[:, :, cached_len:stop, :stop]
    step_embeds = self._embed_base_tokens(step_ids)
    outputs = self.model.transformer(
        attention_mask=step_bias,
        past_key_values=past_key_values,
        inputs_embeds=step_embeds,
        use_cache=True,
        output_attentions=False,
        output_hidden_states=False,
        cache_position=step_positions,
    )
    return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values

def _run_depth_decode_step(
    self,
    token_ids: torch.Tensor,
    *,
    past_key_values: Cache,
    attention_bias: torch.Tensor,
) -> tuple[torch.Tensor, Cache]:
    """Depth-decoding alias that forwards unchanged to ``_run_ar_decode_step``."""
    return self._run_ar_decode_step(
        token_ids,
        past_key_values=past_key_values,
        attention_bias=attention_bias,
    )
def _continue_discrete_generation_from_output(
    self,
    initial_output: MolmoAct2CausalLMOutputWithPast,
    *,
    past_key_values: Cache | None,
    attention_mask: torch.Tensor | None,
    end_token_id: int,
    max_steps: int,
    attention_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Greedily decode tokens until every batch row emits ``end_token_id``.

    Decoding starts from the last-position logits of ``initial_output``. Each
    chosen token is fed back through ``_consume_generation_tokens`` when
    ``attention_bias`` is ``None``, otherwise through ``_run_ar_decode_step``
    followed by ``lm_head``.

    Returns:
        The generated tokens stacked along dimension 1, end token included.

    Raises:
        RuntimeError: if no token was generated, or if the end token was not
            produced within ``max_steps`` steps.
    """
    emitted: list[torch.Tensor] = []
    output, cache, mask = initial_output, past_key_values, attention_mask
    use_full_forward = attention_bias is None

    for _ in range(int(max_steps)):
        token = torch.argmax(output.logits[:, -1, :], dim=-1)
        emitted.append(token)
        if bool((token == int(end_token_id)).all()):
            break
        if use_full_forward:
            output, mask = self._consume_generation_tokens(
                token,
                past_key_values=cache,
                attention_mask=mask,
            )
            cache = output.past_key_values
            continue
        hidden, cache = self._run_ar_decode_step(
            token,
            past_key_values=cache,
            attention_bias=attention_bias,
        )
        output = MolmoAct2CausalLMOutputWithPast(
            logits=self.lm_head(hidden),
            past_key_values=cache,
        )
    else:
        # Reached only when the loop ended without hitting the end token.
        if not emitted:
            raise RuntimeError("Discrete continuation generated no tokens.")
        raise RuntimeError(
            f"Discrete continuation did not emit end token {int(end_token_id)} within {int(max_steps)} steps."
        )
    return torch.stack(emitted, dim=1)
bool((next_token == int(self.config.depth_end_token_id)).all()): + hit_depth_end = True + break + if not generated_tokens: + raise RuntimeError("Depth generation produced no tokens.") + if not hit_depth_end: + raise RuntimeError(f"Depth generation did not emit within {max_steps} steps.") + depth_token_ids = torch.stack(generated_tokens, dim=1) + full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1) + full_attention_mask = None + if current_attention_mask is not None: + full_attention_mask = current_attention_mask[:, : full_input_ids.shape[1]] + encoder_kv_states = self.model._extract_kv_states(current_past_key_values) + return _DepthPrefix( + token_ids=depth_token_ids, + depth_bins=self._decode_depth_bins_from_token_ids(depth_token_ids), + full_input_ids=full_input_ids, + attention_mask=full_attention_mask, + encoder_kv_states=encoder_kv_states, + next_output=current_output, + past_key_values=current_past_key_values, + ) + + depth_start = torch.full( + (batch_size,), + int(self.config.depth_start_token_id), + device=self.device, + dtype=torch.long, + ) + code_token_ids = torch.arange( + int(self.config.depth_token_start_id), + int(self.config.depth_token_start_id) + int(self.config.num_depth_tokens), + device=self.device, + dtype=torch.long, + ) + depth_attention_bias = self._make_depth_decode_attention_bias(inputs, current_past_key_values) + generated_tokens.append(depth_start) + last_hidden, current_past_key_values = self._run_depth_decode_step( + depth_start, + past_key_values=current_past_key_values, + attention_bias=depth_attention_bias, + ) + previous_image = None + previous_bins = None + if depth_cache is not None: + previous_image = depth_cache.get("image") + previous_bins = depth_cache.get("depth_bins") + selective = ( + bool(enable_adaptive_depth) + and latest_first_image is not None + and previous_image is not None + and previous_bins is not None + ) + update_mask = None + previous_buffer_t = None + if selective: + previous_buffer = 
np.asarray(previous_bins, dtype=np.int64).reshape(-1) + if previous_buffer.shape[0] == int(self.config.num_depth_codes): + update_mask = _compute_depth_update_mask( + latest_first_image, + _normalize_image_for_cache(previous_image), + num_depth_codes=int(self.config.num_depth_codes), + ) + previous_buffer_t = ( + torch.from_numpy(previous_buffer) + .to( + device=self.device, + dtype=torch.long, + ) + .unsqueeze(0) + ) + else: + selective = False + + depth_bins = torch.zeros( + (batch_size, int(self.config.num_depth_codes)), + device=self.device, + dtype=torch.long, + ) + num_depth_codes = int(self.config.num_depth_codes) + if not selective or update_mask is None or previous_buffer_t is None: + for depth_idx in range(num_depth_codes): + depth_logits = self._project_depth_logits(last_hidden) + predicted_bins = depth_logits.squeeze(1).argmax(dim=-1) + depth_bins[:, depth_idx] = predicted_bins + chosen_token_ids = code_token_ids[predicted_bins] + generated_tokens.append(chosen_token_ids) + last_hidden, current_past_key_values = self._run_depth_decode_step( + chosen_token_ids, + past_key_values=current_past_key_values, + attention_bias=depth_attention_bias, + ) + else: + for start_idx, end_idx, should_generate in _build_depth_update_spans(update_mask): + if should_generate: + for depth_idx in range(start_idx, end_idx): + depth_logits = self._project_depth_logits(last_hidden) + predicted_bins = depth_logits.squeeze(1).argmax(dim=-1) + depth_bins[:, depth_idx] = predicted_bins + chosen_token_ids = code_token_ids[predicted_bins] + generated_tokens.append(chosen_token_ids) + last_hidden, current_past_key_values = self._run_depth_decode_step( + chosen_token_ids, + past_key_values=current_past_key_values, + attention_bias=depth_attention_bias, + ) + continue + replay_bins = previous_buffer_t[:, start_idx:end_idx].expand(batch_size, -1) + depth_bins[:, start_idx:end_idx] = replay_bins + replay_token_ids = code_token_ids[replay_bins] + 
generated_tokens.extend(replay_token_ids.unbind(dim=1)) + last_hidden, current_past_key_values = self._run_depth_decode_step( + replay_token_ids, + past_key_values=current_past_key_values, + attention_bias=depth_attention_bias, + ) + hit_depth_end = False + max_depth_end_steps = max(8, self.model._resolve_action_horizon()) + full_logits = self.lm_head(last_hidden) + for _ in range(max_depth_end_steps): + next_token = full_logits.squeeze(1).argmax(dim=-1) + generated_tokens.append(next_token) + last_hidden, current_past_key_values = self._run_depth_decode_step( + next_token, + past_key_values=current_past_key_values, + attention_bias=depth_attention_bias, + ) + full_logits = self.lm_head(last_hidden) + if bool((next_token == int(self.config.depth_end_token_id)).all()): + hit_depth_end = True + break + if not hit_depth_end: + raise RuntimeError( + f"Depth generation did not emit within {max_depth_end_steps} steps " + "after adaptive depth tokens." + ) + + depth_token_ids = torch.stack(generated_tokens, dim=1) + full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1) + attention_mask = inputs.get("attention_mask") + if attention_mask is not None: + full_attention_mask = torch.cat( + (attention_mask, attention_mask.new_ones(depth_token_ids.shape)), + dim=-1, + )[:, : full_input_ids.shape[1]] + else: + full_attention_mask = None + current_output = MolmoAct2CausalLMOutputWithPast( + logits=full_logits, + past_key_values=current_past_key_values, + ) + encoder_kv_states = self.model._extract_kv_states(current_past_key_values) + return _DepthPrefix( + token_ids=depth_token_ids, + depth_bins=depth_bins, + full_input_ids=full_input_ids, + attention_mask=full_attention_mask, + encoder_kv_states=encoder_kv_states, + next_output=current_output, + past_key_values=current_past_key_values, + ) + + def _decode_discrete_action_chunk( + self, + generated_token_ids: torch.Tensor, + *, + action_tokenizer: Any, + action_dim: int, + action_horizon: int, + ) -> 
torch.Tensor: + if action_tokenizer is None: + raise ValueError("inference_action_mode='discrete' requires an `action_tokenizer` input.") + if self.config.action_start_token_id is None or self.config.action_end_token_id is None: + raise RuntimeError("Discrete action generation requires / token IDs.") + token_id_to_bin = self._action_token_id_to_bin() + if not token_id_to_bin: + raise RuntimeError( + "Discrete action generation requires indexed action tokens in the converted config." + ) + discrete_token_ids = _extract_discrete_token_bins( + _flatten_generated_token_ids(generated_token_ids), + int(self.config.action_start_token_id), + int(self.config.action_end_token_id), + token_id_to_bin, + ) + if not discrete_token_ids: + raise RuntimeError( + "Model generated no decodable action tokens between /." + ) + try: + decoded = action_tokenizer.decode( + [discrete_token_ids], + time_horizon=int(action_horizon), + action_dim=int(action_dim), + ) + except TypeError: + decoded = action_tokenizer.decode([discrete_token_ids]) + action_chunk = np.asarray(decoded, dtype=np.float32) + if action_chunk.ndim == 1: + action_chunk = action_chunk[None, None, :] + elif action_chunk.ndim == 2: + action_chunk = action_chunk[None, :, :] + elif action_chunk.ndim > 3: + action_chunk = action_chunk.reshape(1, action_chunk.shape[-2], action_chunk.shape[-1]) + if action_chunk.ndim != 3: + raise RuntimeError(f"Decoded action chunk has unexpected shape {action_chunk.shape}.") + return torch.as_tensor(action_chunk, device=self.device, dtype=torch.float32) + + @torch.no_grad() + def predict_action( + self, + *, + processor: Any, + images: Any, + task: str, + state: Any, + norm_tag: str, + inference_action_mode: str | None = None, + enable_depth_reasoning: bool = False, + enable_adaptive_depth: bool = True, + depth_cache: Mapping[str, Any] | None = None, + action_tokenizer: Any = None, + num_steps: int | None = None, + n_action_steps: int | None = None, + generator: torch.Generator | None = None, 
+ normalize_language: bool = True, + enable_cuda_graph: bool = True, + return_dict: bool = True, + ) -> MolmoAct2ActionOutput | torch.Tensor: + if state is None: + raise ValueError("MolmoAct2 `predict_action` requires `state` for discrete state prompting.") + if inference_action_mode is None: + raise ValueError( + "`inference_action_mode` must be provided explicitly as either 'continuous' or 'discrete'." + ) + inference_action_mode = str(inference_action_mode) + if inference_action_mode not in {"continuous", "discrete"}: + raise ValueError("inference_action_mode must be either 'continuous' or 'discrete'.") + if inference_action_mode == "continuous" and not bool(self.config.add_action_expert): + raise RuntimeError( + "inference_action_mode='continuous' requires an action expert, but this checkpoint " + "was converted with add_action_expert=False." + ) + if inference_action_mode == "continuous" and self.config.action_mode not in { + "continuous", + "both", + }: + raise ValueError( + "inference_action_mode='continuous' requires checkpoint action_mode in " + f"{{'continuous', 'both'}}, got {self.config.action_mode!r}." + ) + if inference_action_mode == "discrete": + if action_tokenizer is None: + raise ValueError("inference_action_mode='discrete' requires an `action_tokenizer` input.") + if self.config.action_mode not in {"discrete", "both"}: + raise ValueError( + "inference_action_mode='discrete' requires checkpoint action_mode in " + f"{{'discrete', 'both'}}, got {self.config.action_mode!r}." 
+ ) + if enable_depth_reasoning and not bool(self.config.enable_depth_reasoning): + raise ValueError("this model was not trained with `--enable_depth_reasoning`.") + + stats = self._get_robot_stats() + norm_tag = stats.validate_tag(norm_tag) + metadata = stats.get_metadata(norm_tag) + normalized_state = np.asarray(stats.normalize_state(state, norm_tag), dtype=np.float32) + num_state_tokens = int(self.config.num_state_tokens or 0) + if num_state_tokens <= 0: + raise RuntimeError( + "Discrete state prompting requires indexed state tokens in the converted config." + ) + discrete_state_string = _build_discrete_state_string(normalized_state, num_state_tokens) + style = "robot_depth_action" if enable_depth_reasoning else "robot_action" + task_text = str(task or "") + if normalize_language: + task_text = _normalize_question_text(task_text) + text = _build_robot_text( + task=task_text, + style=style, + discrete_state_string=discrete_state_string, + setup_type=str(metadata.get("setup_type", "") or ""), + control_mode=str(metadata.get("control_mode", "") or ""), + add_setup_tokens=bool(self.config.add_setup_tokens), + add_control_tokens=bool(self.config.add_control_tokens), + num_images=self._count_images(images), + ) + inputs = processor(text=text, images=images, return_tensors="pt") + inputs = self._move_inputs_to_device(inputs, self.device) + inputs = self._drop_trivial_attention_mask(inputs) + + action_dim = stats.get_action_dim(norm_tag) + if action_dim is None: + action_dim = int(self.config.max_action_dim) + action_dim = int(action_dim) + max_action_horizon = self.model._resolve_action_horizon() + action_horizon = stats.get_action_horizon(norm_tag) or max_action_horizon + if int(action_horizon) > max_action_horizon: + raise ValueError( + f"Tag action_horizon={int(action_horizon)} exceeds checkpoint max_action_horizon={max_action_horizon}." 
+ ) + generation_horizon = int(action_horizon) + resolved_n_action_steps = n_action_steps + if resolved_n_action_steps is None: + resolved_n_action_steps = stats.get_n_action_steps(norm_tag) + if resolved_n_action_steps is None: + resolved_n_action_steps = int(action_horizon) + resolved_n_action_steps = int(resolved_n_action_steps) + if resolved_n_action_steps < 1: + raise ValueError(f"n_action_steps must be >= 1, got {resolved_n_action_steps}.") + if resolved_n_action_steps > int(action_horizon): + raise ValueError( + f"Requested n_action_steps={resolved_n_action_steps} exceeds tag action_horizon={int(action_horizon)}." + ) + batch_size = int(inputs["input_ids"].shape[0]) + action_dim_is_pad = self._build_action_dim_is_pad( + action_dim=action_dim, + max_action_dim=int(self.config.max_action_dim), + batch_size=batch_size, + device=self.device, + ) + self.model.action_cuda_graph_manager.set_enabled(enable_cuda_graph) + self.depth_decode_cuda_graph_manager.set_enabled(enable_cuda_graph) + + generated_token_ids = None + depth_bins = None + updated_depth_cache = depth_cache + if inference_action_mode == "continuous": + if enable_depth_reasoning: + latest_first_image = _extract_first_image(images) + depth_prefix = self._generate_depth_prefix( + inputs, + latest_first_image=latest_first_image, + depth_cache=depth_cache, + enable_adaptive_depth=bool(enable_adaptive_depth), + ) + generated_token_ids = depth_prefix.token_ids + depth_bins = depth_prefix.depth_bins + actions = self.model.generate_actions_from_inputs( + input_ids=depth_prefix.full_input_ids, + attention_mask=depth_prefix.attention_mask, + action_dim_is_pad=action_dim_is_pad, + action_horizon=generation_horizon, + num_steps=num_steps, + generator=generator, + encoder_kv_states=depth_prefix.encoder_kv_states, + encoder_attention_mask=self.model._get_encoder_attention_mask( + depth_prefix.full_input_ids, + depth_prefix.attention_mask, + ), + ) + if latest_first_image is not None: + updated_depth_cache = { + 
"image": latest_first_image, + "depth_bins": depth_bins.detach().cpu().reshape(-1).numpy().astype(np.int64), + } + else: + actions = self.model.generate_actions_from_inputs( + **inputs, + action_dim_is_pad=action_dim_is_pad, + action_horizon=generation_horizon, + num_steps=num_steps, + generator=generator, + ) + else: + if enable_depth_reasoning: + latest_first_image = _extract_first_image(images) + depth_prefix = self._generate_depth_prefix( + inputs, + latest_first_image=latest_first_image, + depth_cache=depth_cache, + enable_adaptive_depth=bool(enable_adaptive_depth), + ) + action_token_ids = self._continue_discrete_generation_from_output( + depth_prefix.next_output, + past_key_values=depth_prefix.past_key_values, + attention_mask=depth_prefix.attention_mask, + end_token_id=self._require_eos_token_id(), + max_steps=max(1, int(generation_horizon * 16)), + ) + generated_token_ids = torch.cat([depth_prefix.token_ids, action_token_ids], dim=1) + depth_bins = depth_prefix.depth_bins + if latest_first_image is not None: + updated_depth_cache = { + "image": latest_first_image, + "depth_bins": depth_bins.detach().cpu().reshape(-1).numpy().astype(np.int64), + } + else: + max_action_decode_steps = max(1, int(generation_horizon * 16)) + action_attention_bias = None + if enable_cuda_graph: + action_static_cache = self._make_ar_decode_static_cache( + inputs, + max_steps=max_action_decode_steps, + ) + action_attention_bias = self._make_depth_decode_attention_bias( + inputs, + action_static_cache, + ) + prefill_output = self( + **inputs, + use_cache=True, + past_key_values=action_static_cache, + ) + else: + prefill_output = self(**inputs, use_cache=True) + action_token_ids = self._continue_discrete_generation_from_output( + prefill_output, + past_key_values=prefill_output.past_key_values, + attention_mask=inputs.get("attention_mask"), + end_token_id=self._require_eos_token_id(), + max_steps=max_action_decode_steps, + attention_bias=action_attention_bias, + ) + 
generated_token_ids = action_token_ids + actions = self._decode_discrete_action_chunk( + generated_token_ids, + action_tokenizer=action_tokenizer, + action_dim=action_dim, + action_horizon=generation_horizon, + ) + + actions = self._slice_action_dim(actions, action_dim) + actions = self._slice_action_chunk(actions, int(self.config.n_obs_steps), resolved_n_action_steps) + actions = stats.unnormalize_action(actions, norm_tag) + if not torch.is_tensor(actions): + actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32) + else: + actions = actions.to(device=self.device, dtype=torch.float32) + output = MolmoAct2ActionOutput( + actions=actions, + generated_token_ids=generated_token_ids, + depth_bins=depth_bins, + depth_cache=updated_depth_cache, + ) + if return_dict: + return output + return actions + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.Tensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MolmoAct2CausalLMOutputWithPast: + r""" + ```python + >>> from PIL import Image + >>> import requests + >>> from lerobot.policies.molmoact2.hf_model.modeling_molmoact2 import 
MolmoAct2ForConditionalGeneration + >>> from lerobot.policies.molmoact2.processor_molmoact2 import _load_local_molmoact2_processor + + >>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...") + >>> processor = _load_local_molmoact2_processor("...") + + >>> prompt = "What's the content of the image?" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> messages = [{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}] + + >>> inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=15) + >>> generated_tokens = generated_ids[:, inputs['input_ids'].size(1):] + >>> processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a bustling street scene in what appears to be a Chinatown area. There's ..." 
+ ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + image_token_pooling=image_token_pooling, + image_grids=image_grids, + image_num_crops=image_num_crops, + pixel_values_videos=pixel_values_videos, + video_token_pooling=video_token_pooling, + video_grids=video_grids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size) + + return MolmoAct2CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | 
torch.Tensor | None = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + include_visual_inputs = past_key_values is None + if past_key_values is not None and hasattr(past_key_values, "get_seq_length"): + include_visual_inputs = int(past_key_values.get_seq_length()) == 0 + if include_visual_inputs: + model_inputs["pixel_values"] = pixel_values + model_inputs["image_token_pooling"] = image_token_pooling + model_inputs["image_grids"] = image_grids + model_inputs["image_num_crops"] = image_num_crops + model_inputs["pixel_values_videos"] = pixel_values_videos + model_inputs["video_token_pooling"] = video_token_pooling + model_inputs["video_grids"] = video_grids + + return model_inputs + + # Adapted from transformers.models.gemma3.modeling_gemma3 + @staticmethod + def create_masks_for_generate( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + cache_position: torch.Tensor, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + token_type_ids: torch.Tensor | None = None, + **kwargs, + ) -> dict: + # Prepare mask arguments + mask_kwargs = { + "config": config.get_text_config(), + "input_embeds": input_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Add the token type ids mask for generate as well + if token_type_ids is not None and input_embeds.shape[1] != 1: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device) + ) + + return create_masks_for_generate(**mask_kwargs) diff --git 
a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py new file mode 100644 index 000000000..7b8775faa --- /dev/null +++ b/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +""" +Processor class for MolmoAct2. 
+""" + +from typing import Optional, Union +import dataclasses + +import numpy as np + +from transformers.image_utils import ImageInput +from transformers.video_utils import VideoInput +from transformers.processing_utils import ( + Unpack, + ProcessingKwargs, + ProcessorMixin, +) +from transformers.feature_extraction_utils import BatchFeature +from transformers.tokenization_utils_base import TextInput, PreTokenizedInput +from transformers.utils import logging + +from transformers import AutoTokenizer +from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor +from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor + + +logger = logging.get_logger(__name__) + + +# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them +IMAGE_PATCH_TOKEN = f"" # Where to insert high-res tokens +IMAGE_LOW_RES_TOKEN = f"" # Where to insert low-res tokens +IM_START_TOKEN = f"" +LOW_RES_IMAGE_START_TOKEN = f"" +FRAME_START_TOKEN = f"" +IM_END_TOKEN = f"" +FRAME_END_TOKEN = f"" +IM_COL_TOKEN = f"" +IMAGE_PROMPT = "<|image|>" +VIDEO_PROMPT = "<|video|>" + +IMAGE_TOKENS = [ + IMAGE_PATCH_TOKEN, + IM_COL_TOKEN, + IM_START_TOKEN, + LOW_RES_IMAGE_START_TOKEN, + FRAME_START_TOKEN, + IM_END_TOKEN, + FRAME_END_TOKEN, + IMAGE_LOW_RES_TOKEN, +] + + +class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False): + """MolmoAct2 processor kwargs""" + + images_kwargs: MolmoAct2ImagesKwargs + videos_kwargs: MolmoAct2VideoProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": True, + }, + "videos_kwargs": {"return_metadata": True}, + } + + +class MolmoAct2Processor(ProcessorMixin): + attributes = ["image_processor", "video_processor", "tokenizer"] + optional_attributes = [ + "chat_template", + "time_mode", + "image_use_col_tokens", + "use_single_crop_col_tokens", + "use_single_crop_start_token", + "video_use_col_tokens", + "use_frame_special_tokens", 
+ ] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor: MolmoAct2ImageProcessor = None, + video_processor: MolmoAct2VideoProcessor = None, + tokenizer: AutoTokenizer = None, + chat_template: str | None = None, + image_use_col_tokens: bool | None = True, + use_single_crop_col_tokens: bool | None = None, + use_single_crop_start_token: bool | None = True, + video_use_col_tokens: bool | None = False, + use_frame_special_tokens: bool | None = True, + **kwargs, + ) -> None: + super().__init__( + image_processor, + video_processor, + tokenizer, + chat_template=chat_template, + ) + self.image_use_col_tokens = image_use_col_tokens + self.use_single_crop_col_tokens = use_single_crop_col_tokens + self.use_single_crop_start_token = use_single_crop_start_token + self.video_use_col_tokens = video_use_col_tokens + self.use_frame_special_tokens = use_frame_special_tokens + + self.image_placeholder_token = IMAGE_PROMPT + self.video_placeholder_token = VIDEO_PROMPT + self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS] + + def get_image_tokens(self, image_grid: np.ndarray): + resized_h, resized_w, height, width = image_grid + if int(height) == 0 or int(width) == 0: + per_row = np.full(resized_w, IMAGE_PATCH_TOKEN) + use_single_crop_col_tokens = ( + self.image_use_col_tokens + if self.use_single_crop_col_tokens is None + else self.use_single_crop_col_tokens + ) + if use_single_crop_col_tokens: + per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0) + joint = [ + [IM_START_TOKEN], + np.tile(per_row, [resized_h]), + [IM_END_TOKEN], + ] + return np.concatenate(joint) + per_row = np.full(width, IMAGE_PATCH_TOKEN) + if self.image_use_col_tokens: + per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0) + joint = [ + [IM_START_TOKEN], + np.tile(per_row, [height]), + [IM_END_TOKEN], + ] + per_row = np.full(resized_w, 
IMAGE_PATCH_TOKEN) + use_single_crop_col_tokens = ( + self.image_use_col_tokens + if self.use_single_crop_col_tokens is None + else self.use_single_crop_col_tokens + ) + image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN + if use_single_crop_col_tokens: + per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0) + joint = [ + [image_start_token], + np.tile(per_row, [resized_h]), + [IM_END_TOKEN], + ] + joint + + return np.concatenate(joint) + + def get_video_string( + self, + video_grid: np.ndarray, + timestamps: np.ndarray, + ): + if self.use_frame_special_tokens: + start_token_id = FRAME_START_TOKEN + end_token_id = FRAME_END_TOKEN + else: + start_token_id = IM_START_TOKEN + end_token_id = IM_END_TOKEN + + num_frames, h, w = video_grid + video_string: str = "" + for frame_idx, frame_time in enumerate(timestamps): + # `per-frame-compact` time mode + prev_space = " " if frame_idx > 0 else "" + frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens + + video_string += frame_prefix + per_row = np.full(w, IMAGE_PATCH_TOKEN) + if self.video_use_col_tokens: + per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0) + extra_tokens = np.tile(per_row, [h]) + video_tokens = [ + [start_token_id], + extra_tokens, + [end_token_id], + ] + video_string += "".join(np.concatenate(video_tokens, 0)) + + return video_string + + def insert_bos( + self, + input_ids: np.ndarray, + attention_mask: np.ndarray, + bos_token_id: int, + pad_token_id: int, + ): + """ + Args: + input_ids: [B, S] array with left padding + attention_mask: [B, S] array (0 for pad, 1 for valid) + bos_token_id: int + pad_token_id: int + Returns: + input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed + attention_mask_out: same shape as input_ids_out + """ + + need_to_expand = len(input_ids.shape) == 1 + if need_to_expand: + input_ids = input_ids[None, :] + attention_mask = attention_mask[None, :] + + B, S = 
input_ids.shape + + # Handle zero-length sequence + if S == 0: + new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype) + new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype) + if need_to_expand: + new_input_ids = new_input_ids[0] + new_attention_mask = new_attention_mask[0] + return new_input_ids, new_attention_mask + + first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B] + bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id) + + if bos_already_present: + if need_to_expand: + input_ids = input_ids[0] + attention_mask = attention_mask[0] + return input_ids, attention_mask + else: + new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype) + new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype) + + src_idx = np.tile(np.arange(S), (B, 1)) # [B, S] + valid_mask = src_idx >= first_valid_index[:, None] # [B, S] + tgt_idx = src_idx + 1 # shit right + batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S] + + # flatten valid_positions + flat_vals = input_ids[valid_mask] + flat_batch = batch_idx[valid_mask] + flat_tgt = tgt_idx[valid_mask] + + new_input_ids[flat_batch, flat_tgt] = flat_vals + new_attention_mask[flat_batch, flat_tgt] = 1 + + insert_pos = first_valid_index + new_input_ids[np.arange(B), insert_pos] = bos_token_id + new_attention_mask[np.arange(B), insert_pos] = 1 + + if need_to_expand: + new_input_ids = new_input_ids[0] + new_attention_mask = new_attention_mask[0] + + return new_input_ids, new_attention_mask + + def __call__( + self, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + images: ImageInput = None, + videos: VideoInput = None, + **kwargs: Unpack[MolmoAct2ProcessorKwargs], + ) -> BatchFeature: + """ + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). 
If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + videos (`dict[str, Any]` or `list[dict[str, Any]]`): + The video or batch of videos to be prepared. Each video can be a dictionary with the following keys: + - `"frames"`: `np.ndarray` of shape (T, H, W, 3) + - `"timestamps"`: `np.ndarray` of shape (T,) + - `"sampled_fps"`: `float` (optional) + - `"sampling_augmentation"`: `str` (optional) + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + `BatchFeature`: A [`BatchFeature`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`. + Returned when `images` is not `None`. + - **image_grids** -- Grids of images. Returned when `images` is not `None`. + - **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`. 
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`. + Returned when `videos` is not `None`. + - **video_grids** -- Grids of videos. Returned when `videos` is not `None`. + """ + + output_kwargs = self._merge_kwargs( + MolmoAct2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + image_grids = image_inputs["image_grids"] + else: + image_inputs = {} + image_grids = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grids = videos_inputs["video_grids"] + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + else: + videos_inputs = {} + video_grids = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + + if image_grids is not None: + index = 0 + for i in range(len(text)): + num_images = text[i].count(self.image_placeholder_token) + image_grids_i = image_grids[index : index + num_images] + for image_grid in image_grids_i: + image_tokens = self.get_image_tokens(image_grid) + image_string = "".join(image_tokens) + text[i] = text[i].replace(self.image_placeholder_token, image_string, 1) + index += num_images + + if video_grids is not None: + index = 0 + for i in range(len(text)): + num_videos = text[i].count(self.video_placeholder_token) + assert num_videos in {0, 1}, "At most one video is supported for now" + video_grids_i = video_grids[index : index + num_videos] + metadata_i = video_metadata[index : index + num_videos] + for video_grid, metadata in zip(video_grids_i, 
metadata_i): + video_string = self.get_video_string( + video_grid, + metadata.timestamps, + ) + text[i] = text[i].replace(self.video_placeholder_token, video_string, 1) + index += num_videos + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + input_ids = text_inputs["input_ids"] + attention_mask = text_inputs["attention_mask"] + + input_ids = np.array(input_ids) + attention_mask = np.array(attention_mask) + + bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id + input_ids, attention_mask = self.insert_bos( + input_ids, attention_mask, bos, self.tokenizer.pad_token_id + ) + + if return_mm_token_type_ids: + image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype) + token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1) + text_inputs["token_type_ids"] = token_type_ids.tolist() + + text_inputs["input_ids"] = input_ids.tolist() + text_inputs["attention_mask"] = attention_mask.tolist() + + return BatchFeature( + data={**text_inputs, **image_inputs, **videos_inputs}, + tensor_type=return_tensors, + ) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. 
Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + +MolmoAct2Processor.register_for_auto_class() diff --git a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py new file mode 100644 index 000000000..644d5a691 --- /dev/null +++ b/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py @@ -0,0 +1,997 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa + +"""Video processor class for MolmoAct2""" + +from functools import partial +import os +import warnings +from contextlib import redirect_stdout +from io import BytesIO +from urllib.parse import urlparse +from typing import Optional, Union +from collections.abc import Callable + +import numpy as np +import requests +import einops +import torch +import torchvision.transforms + +from transformers.image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageInput, + PILImageResampling, + SizeDict, + validate_kwargs, +) +from transformers.video_utils import ( + VideoInput, + is_valid_video, + make_batched_videos, + make_batched_metadata, + VideoMetadata, +) +from transformers.processing_utils import Unpack, VideosKwargs +from transformers.video_processing_utils import BaseVideoProcessor +from transformers.utils import logging +from transformers.feature_extraction_utils import BatchFeature +from transformers.utils import ( + is_av_available, + is_decord_available, + is_torchcodec_available, + is_yt_dlp_available, + TensorType, + logging, + to_numpy, +) + + +logger = logging.get_logger(__name__) + +MAX_VIDEO_FPS = 8 + + +def normalize_image( + image: np.ndarray, + image_mean: list[float], + image_std: list[float], +) -> np.ndarray: + if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]): + return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32) + image -= np.array(image_mean, dtype=np.float32)[None, None, :] + image /= np.array(image_std, dtype=np.float32)[None, None, :] + return image + + +def resize_image( + image: np.ndarray, + desired_output_size: list[int], + resample: PILImageResampling, +) -> np.ndarray: + if len(image.shape) == 3: + is_video = False + image = torch.permute(torch.from_numpy(image), [2, 0, 1]) + else: + is_video = True + image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2]) + dtype = image.dtype + if torch.is_floating_point(image): + in_min = 
0.0 + in_max = 1.0 + resized = torchvision.transforms.Resize( + desired_output_size, + resample, + antialias=False, + )(image) + resized = torch.clip(resized, 0.0, 1.0).to(dtype) + else: + assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format( + image.dtype + ) + in_min = 0.0 + in_max = 255.0 + resized = torchvision.transforms.Resize( + desired_output_size, + resample, + antialias=False, + )(image) + resized = torch.clip(resized, 0, 255).to(dtype) + + resized = resized.to(torch.float32) + resized = (resized - in_min) / (in_max - in_min) + + if is_video: + resized = torch.permute(resized, [0, 2, 3, 1]).numpy() + else: + resized = torch.permute(resized, [1, 2, 0]).numpy() + + return resized + + +def build_resized_image( + image: np.ndarray, + base_image_input_size: list[int], + resample: PILImageResampling, + image_mean: list[float], + image_std: list[float], + image_patch_size: int, +) -> tuple[np.ndarray, np.ndarray]: + resized = resize_image( + image, + base_image_input_size, + resample, + ) + resized = normalize_image(resized, image_mean, image_std) + if len(resized.shape) == 3: + resized = np.expand_dims(resized, 0) + crop_patch_w = base_image_input_size[1] // image_patch_size + crop_patch_h = base_image_input_size[0] // image_patch_size + resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w]) + return resized, resize_idx + + +def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray: + """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]""" + if len(array.shape) == 3: + n_crops, h, w = array.shape + h_patches = h // patch_size + w_patches = w // patch_size + array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size]) + array = np.transpose(array, [0, 1, 3, 2, 4]) + array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size]) + return array + else: + n_crops, h, w, c = array.shape + 
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """
    Group a 2-D grid of patch indices into `pool_h x pool_w` pooling windows.

    The grid is first padded with `-1` so both sides become multiples of the window
    size. When the padding is odd, the extra row/column goes to the bottom/right.

    Returns:
        `np.ndarray` of shape `[padded_h // pool_h, padded_w // pool_w, pool_h * pool_w]`;
        the last axis lists the indices inside one window in row-major order.
    """
    n_rows, n_cols = idx_arr.shape[0], idx_arr.shape[1]
    # Rows/columns still missing to reach the next multiple of the window size.
    extra_rows = -n_rows % pool_h
    extra_cols = -n_cols % pool_w
    top = extra_rows // 2
    left = extra_cols // 2

    padded = np.pad(
        idx_arr,
        ((top, extra_rows - top), (left, extra_cols - left)),
        mode="constant",
        constant_values=-1,
    )

    out_h = padded.shape[0] // pool_h
    out_w = padded.shape[1] // pool_w
    # "(h dh) (w dw) -> h w (dh dw)": split both axes, bring the window axes together,
    # then flatten each window.
    windows = padded.reshape(out_h, pool_h, out_w, pool_w).transpose(0, 2, 1, 3)
    return windows.reshape(out_h, out_w, pool_h * pool_w)
def get_candidate_target_fps(
    video_fps: int | float,
    sampling_fps: int | float,
    max_fps: int | float = MAX_VIDEO_FPS,
) -> list[float]:
    """
    Return the multiples of `sampling_fps` that evenly divide `video_fps` and do not
    exceed `max_fps`, in ascending order.

    All three arguments are truncated to `int` before use, so a fractional frame rate
    such as 29.97 is treated as 29.

    Args:
        video_fps (`int` or `float`):
            Native frame rate of the video.
        sampling_fps (`int` or `float`):
            Base sampling rate; every candidate is a multiple of it.
        max_fps (`int` or `float`, *optional*, defaults to `MAX_VIDEO_FPS`):
            Upper bound (inclusive) on the returned candidates.

    Returns:
        `list[float]`: The candidate rates as floats. The list is empty when
        `sampling_fps` itself is larger than `max_fps`.

    Raises:
        ValueError: If `sampling_fps` is `None`, if either rate is not positive after
            truncation, or if `sampling_fps` does not divide `video_fps`.

    Examples:
        >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
        [2.0, 6.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
        [1.0, 5.0]
        >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
        [2.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
        Traceback (most recent call last):
            ...
        ValueError: sampling_fps=2 must divide video_fps=5.
    """
    # Must run before the int() conversions below; int(None) would raise a TypeError
    # and this message would never be seen.
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")

    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)

    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
    if video_fps % sampling_fps != 0:
        raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")

    candidates = []
    for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
        if candidate > max_fps:
            # Candidates are ascending, so nothing later can fit either.
            break
        if video_fps % candidate == 0:
            candidates.append(float(candidate))

    return candidates
+ """ + # Lazy import from decord + import importlib + + decord = importlib.import_module("decord") + + vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu + video_fps = vr.get_avg_fps() + total_num_frames = len(vr) + time_stamps = vr.get_frame_timestamp(list(range(len(vr)))) + duration = time_stamps[-1][1] - time_stamps[0][0] + + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="decord", + ) + + target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs) + target_timestamps = np.array(target_timestamps) + offset = time_stamps[0, 0] + + ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right") + ix = np.minimum(ix, len(time_stamps) - 1) + + video = vr.get_batch(ix).asnumpy() + metadata.update( + { + "frames_indices": target_timestamps * video_fps, + "height": video.shape[1], + "width": video.shape[2], + } + ) + return video, metadata + + +def read_video_torchcodec( + video_path, + sample_timestamps_fn: Callable, + **kwargs, +) -> np.ndarray: + """ + Decode a video using torchcodec decoder. + + Args: + video_path (`str`): + Path to the video file. + sample_timestamps_fn (`Callable`): + A callable function that will return timestamps at which the video should be sampled. + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. 
+ """ + # Lazy import torchcodec + import importlib + + torchcodec = importlib.import_module("torchcodec") + + decoder = torchcodec.decoders.VideoDecoder( + video_path, + # Interestingly `exact` mode takes less than approximate when we load the whole video + seek_mode="exact", + # Allow FFmpeg decide on the number of threads for efficiency + num_ffmpeg_threads=0, + ) + # If the first frame starts at > 0, we effectively clip the video starting at that time + # since (most) video players would also skip to that time + time_offset = decoder.metadata.begin_stream_seconds_from_content + # Note this duration does assume we started playing at `time_offset` + duration = decoder.metadata.duration_seconds + + metadata = VideoMetadata( + total_num_frames=decoder.metadata.num_frames, + fps=decoder.metadata.average_fps, + duration=duration, + video_backend="torchcodec", + height=decoder.metadata.height, + width=decoder.metadata.width, + ) + + target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs) + + # Floating point/rounding issues might cause `target_timestamps` to be very slightly + # out-of-bounds, to handle this we sanity check then clip them + assert all(x >= 0 for x in target_timestamps) + assert all(x < duration + 1e-6 for x in target_timestamps) + # 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the + # exact boundary value, we should still get the first/last frame anyway + max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6 + min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6 + # Note we avoid using numpy ops here to reduce floating precision issues + timestamps = [x + time_offset for x in target_timestamps] + timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps] + + video = ( + decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) + ) # Convert to THWC format + target_timestamps = np.array(target_timestamps) + 
def read_video_pyav(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> tuple[np.ndarray, VideoMetadata]:
    """
    Decode a video using the PyAV backend.

    Every frame of the first video stream is decoded, `sample_timestamps_fn` is asked
    which timestamps to keep, and each timestamp is mapped to the frame being shown at
    that time.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.
            It is called as `sample_timestamps_fn(metadata=metadata, **kwargs)`.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.

    Raises:
        ValueError: If no frame could be decoded from the video.
    """
    # Lazy import so PyAV is only required when this backend is actually selected
    import importlib

    av = importlib.import_module("av")

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        frame_rate = stream.average_rate or stream.guessed_rate
        # PyAV reports rates as fractions; keep a plain float for metadata and index math
        fps = float(frame_rate)

        frames = list(container.decode(video=0))
        if not frames:
            raise ValueError(f"No video frames could be decoded from {video_path}")

        time_base = stream.time_base
        start = frames[0].pts * time_base
        last_frame_start = frames[-1].pts * time_base

        stream_end = stream.duration
        if stream_end is not None:
            stream_end = stream_end * time_base
        # Compare seconds with seconds: the raw PTS is in `time_base` ticks, so comparing
        # it to the converted duration would reject the stream duration almost always.
        # NOTE(review): `stream.duration` is used as an end timestamp here, as in the
        # original code; confirm that is right for streams that do not start at 0.
        if stream_end is None or stream_end < last_frame_start:
            # Some problem with stream duration, so use the frame PTS directly
            # and guess the duration of the last frame
            end = last_frame_start + 1 / frame_rate
        else:
            end = stream_end
        duration = float(end - start)

        metadata = VideoMetadata(
            total_num_frames=len(frames),
            fps=fps,
            duration=duration,
            video_backend="pyav",
            height=stream.height,
            width=stream.width,
        )

        target_timestamps = np.array(sample_timestamps_fn(metadata=metadata, **kwargs))
        offset = float(start)

        # A frame is displayed until the next frame starts; the last one until `end`.
        # All entries are absolute stream times, matching `target_timestamps + offset`.
        end_time_stamps = np.array(
            [float(frame.pts * time_base) for frame in frames[1:]] + [float(end)]
        )
        indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
        indices = np.minimum(indices, len(end_time_stamps) - 1)

        video = np.stack(
            [frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
            axis=0,
        )

    metadata.frames_indices = target_timestamps * fps

    return video, metadata
def get_target_fps(
    video_fps: float,
    max_frames: int,
    total_frames: int,
    frame_sample_mode: str,
    candidate_target_fps: tuple[float],
) -> float:
    """
    Pick, from `candidate_target_fps`, the rate that yields the most sampled frames
    without exceeding `max_frames`.

    Candidates are visited in the given order. The first candidate that yields at
    least one frame is accepted even when it exceeds `max_frames`, unless
    `frame_sample_mode` contains `"uniform"`, in which case the search stops.
    Later candidates replace the choice only if they stay within `max_frames` and
    sample strictly more frames.

    Returns:
        The selected rate, or `None` if nothing was selected.
    """
    chosen_fps = None
    chosen_count = 0
    for fps in candidate_target_fps:
        stride = max(int(video_fps / fps), 1)
        count = int(total_frames / stride)

        if chosen_count == 0:
            # Nothing usable picked yet
            if "uniform" in frame_sample_mode and count > max_frames:
                break
            chosen_fps, chosen_count = fps, count
            continue

        # the candidate sampling fps increases so frame count can't decrease
        assert chosen_count <= count
        if count <= max_frames and count > chosen_count:
            # still within budget and denser than the current choice
            chosen_fps, chosen_count = fps, count
    return chosen_fps
+ """ + if size is not None and ("height" not in size or "width" not in size): + raise ValueError("size must contain 'height' and 'width' keys.") + + return super()._further_process_kwargs(size=size, **kwargs) + + def sample_times( + self, + metadata: VideoMetadata, + frame_sample_mode: str, + num_frames: int, + max_fps: int | None = None, + sampling_fps: int | None = None, + **kwargs, + ) -> np.ndarray: + """ + Time-based sampling if an array video is passed + Args: + metadata (`VideoMetadata`): + Metadata of the video containing information about total duration, fps and total number of frames. + frame_sample_mode (`str`, *optional*): + Mode to sample frames. Defaults to `self.frame_sample_mode`. + num_frames (`int`, *optional*): + Maximum number of frames to sample. Defaults to `self.num_frames`. + man_fps (`int`, *optional*): + Maximum frames per second to sample. + sampling_fps (`int`, *optional*): + Sampling frames per second. Defaults to `self.sampling_fps`. + Used when `frame_sample_mode` is `"fps"`. 
+ """ + frame_sample_mode = frame_sample_mode or self.frame_sample_mode + num_frames = num_frames or self.num_frames + sampling_fps = sampling_fps or self.sampling_fps + + duration = metadata.duration or metadata.total_num_frames / metadata.fps + if frame_sample_mode == "fps": + candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps) + # Try larger and larger FPSs until we hit one that can't span the video + target_fps = candidate_target_fps[0] + for candidate_fps in candidate_target_fps[1:]: + if num_frames / candidate_fps < duration: + break + target_fps = candidate_fps + times = np.arange(0, num_frames) / target_fps + times = times[times < duration] + return times + elif frame_sample_mode == "uniform_last_frame": + if max_fps is not None: + max_duration = (num_frames - 1) / max_fps # -1 to include the last frame + if max_duration < duration: + times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64) + else: + times = np.arange(0.0, stop=duration, step=1 / max_fps) + times = np.concatenate([times, [duration]], axis=0) + assert len(times) <= num_frames + else: + times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64) + return times + else: + raise NotImplementedError(frame_sample_mode) + + def sample_frames( + self, + metadata: VideoMetadata, + frame_sample_mode: str | None = None, + num_frames: int | None = None, + max_fps: int | None = None, + sampling_fps: int | None = None, + **kwargs, + ) -> np.ndarray: + """ + Frame-based sampling if an array video is passed + Args: + metadata (`VideoMetadata`): + Metadata of the video containing information about total duration, fps and total number of frames. + frame_sample_mode (`str`, *optional*): + Mode to sample frames. Defaults to `self.frame_sample_mode`. + num_frames (`int`, *optional*): + Maximum number of frames to sample. Defaults to `self.num_frames`. + max_fps (`int`, *optional*): + Maximum frames per second to sample. 
+ sampling_fps (`int`, *optional*): + Sampling frames per second. Defaults to `self.sampling_fps`. + Used when `frame_sample_mode` is `"fps"`. + """ + frame_sample_mode = frame_sample_mode or self.frame_sample_mode + num_frames = num_frames or self.num_frames + sampling_fps = sampling_fps or self.sampling_fps + + total_num_frames = metadata.total_num_frames + if frame_sample_mode == "uniform_last_frame" and max_fps is not None: + duration = total_num_frames / metadata.fps + if total_num_frames <= 2: + return np.arange(total_num_frames).astype(int) + if duration > (num_frames - 1) / max_fps: # -1 to include the last frame + # uniform fallback + indices = np.linspace( + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, + ).astype(int) + return indices + else: + float_indices = np.arange( + 0.0, + stop=total_num_frames - 1, + step=float(metadata.fps / max_fps), + ) + if np.round(float_indices[-1]) != total_num_frames - 1: + float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0) + indices = np.round(float_indices).astype(int) + assert indices[-1] < total_num_frames + assert len(float_indices) <= num_frames + return indices + elif frame_sample_mode == "uniform_last_frame": + indices = np.linspace( + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, + ).astype(int) + return indices + elif frame_sample_mode == "fps": + candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps) + selected_target_fps = get_target_fps( + metadata.fps, + num_frames, + total_num_frames, + frame_sample_mode, + candidate_target_fps, + ) + _, indices = get_frame_times_and_chosen_fps( + selected_target_fps, + total_num_frames, + num_frames, + metadata.fps, + ) + return indices + else: + raise NotImplementedError(frame_sample_mode) + + def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None): + """ + Convert a single or a list of urls into 
the corresponding `np.array` objects. + + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()): + raise ImportError( + "MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed." + ) + + if is_decord_available(): + backend = "decord" + elif is_torchcodec_available(): + warnings.warn( + "`decord` is not installed and cannot be used to decode the video by default. " + "Falling back to `torchcodec`." + ) + backend = "torchcodec" + else: + warnings.warn( + "`decord` is not installed and cannot be used to decode the video by default. " + "Falling back to `PyAV`." + ) + backend = "pyav" + + if isinstance(video_url_or_urls, list): + return list( + zip( + *[ + self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) + for x in video_url_or_urls + ] + ) + ) + else: + return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn) + + def _decode_and_sample_videos( + self, + videos: VideoInput, + video_metadata: VideoMetadata | dict, + do_sample_frames: bool | None = None, + sample_indices_fn: Callable | None = None, + sample_timestamps_fn: Callable | None = None, + ): + """ + Decode input videos and sample frames if needed. 
+ """ + videos = make_batched_videos(videos) + video_metadata = make_batched_metadata(videos, video_metadata=video_metadata) + + # Framed-based sampling if an array video is passed + # Otherwise, time-based sampling with decoding + if is_valid_video(videos[0]) and do_sample_frames: + assert video_metadata[0].fps is not None, "FPS must be provided for video input" + sampled_videos = [] + sampled_metadata = [] + for video, metadata in zip(videos, video_metadata): + indices = sample_indices_fn(metadata=metadata) + metadata.frames_indices = indices + sampled_videos.append(video[indices]) + sampled_metadata.append(metadata) + videos = sampled_videos + video_metadata = sampled_metadata + elif not is_valid_video(videos[0]): + if sample_indices_fn is None: + logger.warning( + "do_sample_frames is False, but video array is not provided: " + "Will decode the video and sample frames using MolmoAct2's default sampling mode" + ) + if isinstance(videos[0], list): + raise ValueError("A list of images is not supported for video input!") + else: + videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn) + + return videos, video_metadata + + def _prepare_input_videos( + self, + videos: VideoInput, + **kwargs, + ) -> list[np.ndarray]: + processed_videos = [to_numpy(video) for video in videos] + return processed_videos + + def preprocess( + self, + videos: VideoInput, + **kwargs: Unpack[MolmoAct2VideoProcessorKwargs], + ) -> BatchFeature: + validate_kwargs( + captured_kwargs=kwargs.keys(), + valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"], + ) + + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. 
+ for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + do_sample_frames = kwargs.pop("do_sample_frames") + video_metadata = kwargs.pop("video_metadata") + + sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None + sample_timestamps_fn = partial(self.sample_times, **kwargs) + videos, video_metadata = self._decode_and_sample_videos( + videos, + video_metadata=video_metadata, + do_sample_frames=do_sample_frames, + sample_indices_fn=sample_indices_fn, + sample_timestamps_fn=sample_timestamps_fn, + ) + videos = self._prepare_input_videos(videos=videos) + + kwargs = self._further_process_kwargs(**kwargs) + + return_metadata = kwargs.pop("return_metadata") + preprocessed_videos = self._preprocess(videos=videos, **kwargs) + if return_metadata: + preprocessed_videos["video_metadata"] = video_metadata + return preprocessed_videos + + def _preprocess( + self, + videos: list[np.ndarray], + size: SizeDict | None = None, + resample: PILImageResampling | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool | None = None, + patch_size: int | None = None, + pooling_size: list[int] | None = None, + return_tensors: str | TensorType | None = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess a video for the model. + Args: + videos (`VideoInput`): + Video to preprocess. + size (`SizeDict`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`): + The pooling size of the vision adapter. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + + Returns: + A `BatchFeature` containing the following keys: + - `pixel_values_videos`: The preprocessed videos. + - `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`. + - `video_grids`: The video grids. 
+ """ + if size.height is None or size.width is None: + raise ValueError("size must contain 'height' and 'width' keys.") + + base_image_input_size = [size.height, size.width] + + resample = resample or self.resample + image_mean = image_mean or self.image_mean + image_std = image_std or self.image_std + do_convert_rgb = do_convert_rgb or self.do_convert_rgb + + patch_size = patch_size or self.patch_size + pooling_size = pooling_size or self.pooling_size + + image_pooling_h, image_pooling_w = pooling_size + + batch_grids = [] + batch_crops = [] + batch_pooled_patches_idx = [] + + for video in videos: + all_crops = [] + pooled_patches_idx = [] + + for frame in video: + image_grid, crops, pooled_idx = image_to_patches_and_grids( + frame, + base_image_input_size, + resample, + image_mean, + image_std, + patch_size, + image_pooling_w, + image_pooling_h, + ) + offset = sum(np.prod(x.shape[:2]) for x in all_crops) + pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx) + pooled_patches_idx.append(pooled_idx_with_offset) + all_crops.append(crops) + + video_grid = np.array([len(video), image_grid[0], image_grid[1]]) + all_crops = np.concatenate(all_crops, 0) + pooled_patches_idx = np.concatenate(pooled_patches_idx, 0) + + batch_grids.append(video_grid) + batch_crops.append(all_crops) + batch_pooled_patches_idx.append(pooled_patches_idx) + + video_grids = np.stack(batch_grids, 0) + pixel_values_videos = np.concatenate(batch_crops, 0) + video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0) + + data = dict( + pixel_values_videos=pixel_values_videos, + video_token_pooling=video_token_pooling, + video_grids=video_grids, + ) + + return BatchFeature(data, tensor_type=return_tensors) + + +MolmoAct2VideoProcessor.register_for_auto_class() diff --git a/src/lerobot/policies/molmoact2/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/modeling_molmoact2.py new file mode 100644 index 000000000..f86be0904 --- /dev/null +++ 
b/src/lerobot/policies/molmoact2/modeling_molmoact2.py @@ -0,0 +1,1551 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import os +import types +from collections import deque +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +from safetensors.torch import load_file as load_safetensors_file +from torch import Tensor +from torch.distributions import Beta + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION +from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package + +from ..rtc.modeling_rtc import RTCProcessor +from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location + +if TYPE_CHECKING or _transformers_available: + from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + + from .hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config + from .hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration +else: + SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + SAFE_WEIGHTS_NAME = "model.safetensors" + HFMolmoAct2Config = None + MolmoAct2ForConditionalGeneration = None + +if TYPE_CHECKING 
or (_transformers_available and _scipy_available): + from .hf_model.action_tokenizer import UniversalActionProcessor +else: + UniversalActionProcessor = None + +_MODEL_INPUT_KEYS = { + "input_ids", + "pixel_values", + "image_token_pooling", + "image_grids", + "image_num_crops", + "pixel_values_videos", + "video_token_pooling", + "video_grids", + "attention_mask", + "position_ids", + "past_key_values", + "token_type_ids", + "inputs_embeds", +} + + +def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location: str) -> None: + index_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_INDEX_NAME) + single_file_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_NAME) + if os.path.isfile(index_path): + with open(index_path, encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] + loaded_keys = set(weight_map) + model_keys = set(model.state_dict()) + missing_keys = sorted(model_keys - loaded_keys) + unexpected_keys = sorted(loaded_keys - model_keys) + if missing_keys or unexpected_keys: + message = ["MolmoAct2 safetensors do not match the local model implementation."] + if missing_keys: + message.append(f"Missing keys: {missing_keys[:8]}") + if unexpected_keys: + message.append(f"Unexpected keys: {unexpected_keys[:8]}") + raise RuntimeError(" ".join(message)) + for shard_file in sorted(set(weight_map.values())): + state_dict = load_safetensors_file(os.path.join(checkpoint_location, shard_file), device="cpu") + model.load_state_dict(state_dict, strict=False) + del state_dict + return + if os.path.isfile(single_file_path): + state_dict = load_safetensors_file(single_file_path, device="cpu") + model.load_state_dict(state_dict, strict=True) + return + raise FileNotFoundError( + f"MolmoAct2 checkpoint at {checkpoint_location} must contain {SAFE_WEIGHTS_NAME} " + f"or {SAFE_WEIGHTS_INDEX_NAME}." 
+ ) + + +def _torch_dtype(dtype: str) -> torch.dtype: + if dtype == "float32": + return torch.float32 + if dtype == "bfloat16": + return torch.bfloat16 + if dtype == "float16": + return torch.float16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +def _sample_beta_timesteps( + *, + batch_size: int, + device: torch.device, + cutoff: float, + time_offset: float, + time_scale: float, + alpha: float, + beta: float, +) -> Tensor: + if cutoff < time_offset: + raise ValueError(f"flow-matching cutoff must be >= time_offset, got {cutoff} < {time_offset}") + if time_scale <= 0: + raise ValueError(f"flow-matching time_scale must be > 0, got {time_scale}") + upper = min(cutoff, time_offset + time_scale) + dist = Beta(torch.tensor(alpha, device=device), torch.tensor(beta, device=device)) + samples = dist.sample((batch_size,)) + scale = upper - time_offset + if scale == 0: + return torch.full((batch_size,), time_offset, device=device, dtype=samples.dtype) + return time_offset + scale * samples + + +class MolmoAct2Policy(PreTrainedPolicy): + config_class = MolmoAct2Config + name = "molmoact2" + + def __init__( + self, + config: MolmoAct2Config, + *inputs, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + dataset_meta: Any | None = None, + **kwargs, + ): + super().__init__(config, *inputs, **kwargs) + self.config.apply_norm_tag_metadata() + self.config.validate_features() + del inputs, kwargs, dataset_stats, dataset_meta + self._checkpoint_action_mode = self.config.saved_policy_action_mode() + self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps) + self._rollout_action_generator: torch.Generator | None = None + self._rollout_task_key: tuple[Any, ...] 
| None = None + self._rollout_index_for_task = -1 + self.rtc_processor: RTCProcessor | None = None + self.action_tokenizer: Any | None = None + self._load_hf_model() + self.config.validate_inference_action_mode(self._checkpoint_action_mode) + if self.config.enable_lora_vlm: + self._apply_lora_adapters() + self.init_rtc_processor() + + def _load_hf_model(self) -> None: + require_package("transformers", extra="molmoact2") + + checkpoint_location = _resolve_checkpoint_location( + self.config.checkpoint_path, + revision=self.config.checkpoint_revision, + force_download=bool(self.config.checkpoint_force_download), + ) + model_dtype = _torch_dtype(self.config.model_dtype) + if HFMolmoAct2Config is None or MolmoAct2ForConditionalGeneration is None: + raise RuntimeError("transformers is required to load MolmoAct2 checkpoints.") + hf_config = HFMolmoAct2Config.from_pretrained( + checkpoint_location, + token=_hf_token(), + ) + self.model = MolmoAct2ForConditionalGeneration.from_pretrained( + checkpoint_location, + config=hf_config, + dtype=model_dtype, + low_cpu_mem_usage=True, + token=_hf_token(), + ) + # Keep Hub loading limited to local code plus safetensors, and verify the + # local implementation exactly matches the checkpoint key space. + _strict_load_safetensors_weights(self.model, checkpoint_location) + hf_max_action_dim = int(getattr(self.model.config, "max_action_dim", -1)) + if hf_max_action_dim != int(self.config.expected_max_action_dim): + raise ValueError( + "MolmoAct2 checkpoint max_action_dim mismatch: " + f"checkpoint={hf_max_action_dim}, expected={self.config.expected_max_action_dim}." + ) + if hf_max_action_dim != 32: + raise ValueError( + f"MolmoAct2 released checkpoints must have max_action_dim=32, got {hf_max_action_dim}." 
+ ) + + if not hasattr(self.model.config, "max_action_horizon"): + raise ValueError("MolmoAct2 HF checkpoints must define `max_action_horizon`.") + self._override_loaded_max_action_horizon(int(self.config.chunk_size)) + + if not hasattr(self.model.config, "action_mode"): + raise ValueError( + "MolmoAct2 HF checkpoints must define `action_mode`. If this is a released " + "MolmoAct2 checkpoint, refresh the local Hub cache with " + "`policy.checkpoint_force_download=true` after the updated files are pushed." + ) + checkpoint_action_mode = str(self.model.config.action_mode) + self.config.validate_checkpoint_action_mode( + checkpoint_action_mode, + has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)), + ) + + if self.config.freeze_embedding: + self._freeze_input_embeddings() + if self.config.train_action_expert_only: + self._freeze_non_action_expert_parameters() + if self.config.gradient_checkpointing: + self._enable_gradient_checkpointing() + self.train(self.training) + + def reset(self) -> None: + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._rollout_action_generator = None + + def _set_inference_cuda_graph_enabled(self, enabled: bool) -> None: + if not hasattr(self, "model"): + return + hf_model = self._hf_model() + enabled = bool(enabled and getattr(self.config, "enable_inference_cuda_graph", True)) + managers = [ + getattr(self._backbone(), "action_cuda_graph_manager", None), + getattr(hf_model, "action_cuda_graph_manager", None), + getattr(hf_model, "depth_decode_cuda_graph_manager", None), + ] + seen: set[int] = set() + for manager in managers: + if manager is None or id(manager) in seen: + continue + seen.add(id(manager)) + set_enabled = getattr(manager, "set_enabled", None) + if callable(set_enabled): + set_enabled(enabled) + + def init_rtc_processor(self) -> None: + self.rtc_processor = None + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + def 
_rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _action_expert(self) -> torch.nn.Module: + return self._backbone()._require_action_expert() + + def _enable_gradient_checkpointing(self) -> None: + enable_gradient_checkpointing = getattr(self._hf_model(), "gradient_checkpointing_enable", None) + if callable(enable_gradient_checkpointing): + try: + enable_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False}) + except TypeError: + enable_gradient_checkpointing() + else: + transformer = getattr(self._backbone(), "transformer", None) + if transformer is None: + raise RuntimeError("gradient_checkpointing=true, but MolmoAct2 exposes no text transformer.") + transformer.gradient_checkpointing = True + + transformer = getattr(self._backbone(), "transformer", None) + if transformer is not None: + transformer.gradient_checkpointing = True + vision_backbone = getattr(self._backbone(), "vision_backbone", None) + if vision_backbone is not None: + vision_backbone.gradient_checkpointing = True + + def _freeze_non_action_expert_parameters(self) -> None: + trainable_params = 0 + for name, param in self.named_parameters(): + param.requires_grad = "action_expert" in name + if param.requires_grad: + trainable_params += param.numel() + if trainable_params == 0: + raise RuntimeError("train_action_expert_only=true, but no action_expert parameters were found.") + + def _unfreeze_action_expert_parameters(self) -> None: + trainable_params = 0 + for name, param in self.named_parameters(): + if "action_expert" in name: + param.requires_grad_(True) + trainable_params += param.numel() + if trainable_params == 0: + raise RuntimeError("enable_lora_vlm=true, but no action_expert parameters were found.") + + def train(self, mode: bool = True): + super().train(mode) + if getattr(self.config, "train_action_expert_only", False) and hasattr(self, "model"): + self._hf_model().eval() + 
self._action_expert().train(mode) + self._set_inference_cuda_graph_enabled(not mode) + return self + + def _freeze_input_embeddings(self) -> None: + embedding_modules: list[torch.nn.Module] = [] + seen_module_ids: set[int] = set() + hf_model = self._hf_model() + for module in (hf_model, self._backbone()): + get_input_embeddings = getattr(module, "get_input_embeddings", None) + if not callable(get_input_embeddings): + continue + embeddings = get_input_embeddings() + if embeddings is None or id(embeddings) in seen_module_ids: + continue + embedding_modules.append(embeddings) + seen_module_ids.add(id(embeddings)) + + if not embedding_modules: + raise RuntimeError("freeze_embedding=true, but MolmoAct2 checkpoint exposes no input embeddings.") + + lm_head = getattr(hf_model, "lm_head", None) + lm_head_params = {id(param) for param in lm_head.parameters()} if lm_head is not None else set() + embedding_params = [param for embeddings in embedding_modules for param in embeddings.parameters()] + if any(id(param) in lm_head_params for param in embedding_params): + raise RuntimeError( + "freeze_embedding=true would also freeze lm_head because input embeddings and lm_head " + "share parameters in this checkpoint." 
+ ) + for param in embedding_params: + param.requires_grad = False + + def get_optim_params(self) -> list[dict[str, Any]]: + vit_params: list[Tensor] = [] + connector_params: list[Tensor] = [] + action_expert_params: list[Tensor] = [] + vlm_params: list[Tensor] = [] + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + if "action_expert" in name: + action_expert_params.append(param) + elif any(part in name for part in ("image_pooling_2d", "image_projector")): + connector_params.append(param) + elif any(part in name for part in ("vision", "image_encoder", "vit")): + vit_params.append(param) + elif any(part in name for part in ("multi_modal_projector", "connector", "mm_projector")): + connector_params.append(param) + else: + vlm_params.append(param) + + vlm_lr = 5e-5 if self.config.enable_lora_vlm else self.config.optimizer_lr + vit_lr = 5e-5 if self.config.enable_lora_vlm else self.config.optimizer_vit_lr + connector_lr = 5e-5 if self.config.enable_lora_vlm else self.config.optimizer_connector_lr + + groups: list[dict[str, Any]] = [] + if vlm_params: + groups.append({"params": vlm_params, "lr": vlm_lr}) + if vit_params: + groups.append({"params": vit_params, "lr": vit_lr}) + if connector_params: + groups.append({"params": connector_params, "lr": connector_lr}) + if action_expert_params: + groups.append({"params": action_expert_params, "lr": self.config.optimizer_action_expert_lr}) + return groups + + def _model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + compute_dtype = _torch_dtype(self.config.model_dtype) + return { + key: value.to(dtype=compute_dtype) if value.is_floating_point() else value + for key, value in batch.items() + if key in _MODEL_INPUT_KEYS and value is not None + } + + def _output_action_dim(self, batch: dict[str, Tensor]) -> int: + action_feature = self.config.output_features.get(ACTION) + if action_feature is not None and action_feature.shape: + action_dim = int(action_feature.shape[0]) + if 
action_dim > 0: + return action_dim + + action_dim_is_pad = batch.get("action_dim_is_pad") + if action_dim_is_pad is not None: + valid_counts = (~action_dim_is_pad.to(dtype=torch.bool)).sum(dim=-1) + if bool((valid_counts == valid_counts[0]).all()) and int(valid_counts[0]) > 0: + return int(valid_counts[0]) + + raise RuntimeError("MolmoAct2 inference requires a positive action dimension in output_features.") + + def _hf_model(self): + base_model = getattr(self.model, "base_model", None) + wrapped_model = getattr(base_model, "model", None) if base_model is not None else None + return wrapped_model if wrapped_model is not None else self.model + + def _backbone(self): + return self._hf_model().model + + def _override_loaded_max_action_horizon(self, action_horizon: int) -> None: + if action_horizon < 1: + raise ValueError(f"action_horizon must be >= 1, got {action_horizon}.") + hf_model = self._hf_model() + for cfg in (getattr(hf_model, "config", None), getattr(self._backbone(), "config", None)): + if cfg is not None: + cfg.max_action_horizon = int(action_horizon) + + def _generation_action_horizon(self) -> int: + chunk_size = getattr(self.config, "chunk_size", None) + if chunk_size is not None: + return int(chunk_size) + hf_model = self._hf_model() + for cfg in (getattr(hf_model, "config", None), getattr(self._backbone(), "config", None)): + if cfg is None: + continue + value = getattr(cfg, "max_action_horizon", None) + if value is not None: + return int(value) + raise RuntimeError("MolmoAct2 could not resolve an action generation horizon.") + + @staticmethod + def _mask_discrete_action_spans( + *, + input_ids: Tensor, + mask: Tensor, + start_token_id: int | None, + end_token_id: int | None, + ) -> Tensor: + if start_token_id is None or end_token_id is None: + return mask + mask = mask.clone() + for batch_idx in range(input_ids.shape[0]): + row = input_ids[batch_idx] + starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist() + ends = (row == 
int(end_token_id)).nonzero(as_tuple=False).flatten().tolist() + end_ptr = 0 + for start in starts: + while end_ptr < len(ends) and ends[end_ptr] < start: + end_ptr += 1 + if end_ptr >= len(ends): + mask[batch_idx, start:] = False + break + end = int(ends[end_ptr]) + mask[batch_idx, start : end + 1] = False + end_ptr += 1 + return mask + + def _encoder_attention_mask_for_action_expert( + self, + *, + input_ids: Tensor | None, + attention_mask: Tensor | None, + ) -> Tensor | None: + backbone = self._backbone() + get_encoder_attention_mask = getattr(backbone, "_get_encoder_attention_mask", None) + if callable(get_encoder_attention_mask): + mask = get_encoder_attention_mask(input_ids, attention_mask) + elif attention_mask is not None: + mask = attention_mask.to(dtype=torch.bool) + elif input_ids is not None: + mask = input_ids != -1 + else: + return None + + if getattr(self.config, "action_mode", None) != "both" or input_ids is None or mask is None: + return mask + + mask = mask.to(dtype=torch.bool).clone() + eos_token_id = getattr(self.model.config, "eos_token_id", None) + if eos_token_id is not None: + mask &= input_ids != int(eos_token_id) + return self._mask_discrete_action_spans( + input_ids=input_ids, + mask=mask, + start_token_id=getattr(self.model.config, "action_start_token_id", None), + end_token_id=getattr(self.model.config, "action_end_token_id", None), + ) + + @staticmethod + def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]: + attention_mask = model_inputs.get("attention_mask") + if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): + model_inputs = dict(model_inputs) + model_inputs.pop("attention_mask", None) + return model_inputs + + def _load_discrete_action_tokenizer(self) -> Any: + if self.action_tokenizer is None: + require_package("transformers", extra="molmoact2") + require_package("scipy", extra="molmoact2") + + if UniversalActionProcessor is None: + raise 
RuntimeError("transformers and scipy are required to load MolmoAct2 action tokenizer.") + self.action_tokenizer = UniversalActionProcessor.from_pretrained_local( + self.config.discrete_action_tokenizer, + ) + return self.action_tokenizer + + def _resolve_inference_action_mode(self, requested_mode: str | None) -> str: + return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode) + + @staticmethod + def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int: + seed = 0 + for idx in range(batch_size): + seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1) + return seed + + @staticmethod + def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None: + task = batch.get("task") + if task is None: + task = batch.get("observation.language") + if task is None: + return None + if isinstance(task, str): + return (task,) + if isinstance(task, (list, tuple)): + return tuple(str(item) for item in task) + return (str(task),) + + def _rollout_generator_for_inputs( + self, + batch: dict[str, Any], + *, + batch_size: int, + device: torch.device, + ) -> torch.Generator | None: + if not bool(getattr(self.config, "per_episode_seed", False)): + return None + if self._rollout_action_generator is not None: + return self._rollout_action_generator + + task_signature = self._rollout_task_signature(batch) + if task_signature != self._rollout_task_key: + self._rollout_task_key = task_signature + self._rollout_index_for_task = 0 + else: + self._rollout_index_for_task += 1 + + base_seed = int(getattr(self.config, "eval_seed", None) or 0) + first_seed = base_seed + self._rollout_index_for_task * batch_size + generator_device = ( + device if device.type == "cuda" and torch.cuda.is_available() else torch.device("cpu") + ) + generator = torch.Generator(device=generator_device) + generator.manual_seed(self._combine_rollout_seeds(first_seed, batch_size)) + self._rollout_action_generator = generator + return generator + + 
@staticmethod + def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None: + if mask is None: + return None + return ( + mask.unsqueeze(1) + .expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1))) + .reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:]) + ) + + @staticmethod + def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None: + if action_dim_is_pad is None: + return None + mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool) + if mask.ndim == 1: + mask = mask.unsqueeze(0) + if mask.shape[-1] != target.shape[-1]: + raise ValueError( + f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}." + ) + if mask.shape[0] == 1 and target.shape[0] != 1: + mask = mask.expand(target.shape[0], -1) + if mask.shape[0] != target.shape[0]: + raise ValueError( + f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}." + ) + while mask.ndim < target.ndim: + mask = mask.unsqueeze(1) + return mask + + @classmethod + def _mask_action_dim_tensor(cls, tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: + if not cls._mask_enabled_static(action_dim_is_pad): + return tensor + valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad) + if valid_mask is None: + return tensor + return tensor.masked_fill(~valid_mask, 0) + + @staticmethod + def _mask_enabled_static(action_dim_is_pad: Tensor | None) -> bool: + return action_dim_is_pad is not None + + @classmethod + def _apply_action_dim_padding_mask(cls, loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: + valid_mask = cls._action_dim_valid_mask(loss, action_dim_is_pad) + if valid_mask is None: + return loss + valid = valid_mask.to(dtype=loss.dtype) + denom = valid.sum(dim=-1).clamp_min(1.0) + return (loss * valid).sum(dim=-1) / denom + + @staticmethod + def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor: + if 
def _prepare_flow_matching_tensors(
    self,
    *,
    actions: Tensor,
    action_dim_is_pad: Tensor | None,
    timesteps: Tensor | None = None,
    noise: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Build the noisy inputs and regression targets for flow-matching training.

    For every example, ``num_flow_timesteps`` (at least 1) timesteps ``t`` and
    noise tensors are used to form
    ``xt = (1 - t) * noise + t * actions`` with target ``actions - noise``.

    Args:
        actions: Ground-truth actions; indexed as a 3-D tensor whose first axis
            is the batch. It is cast to the action expert's parameter dtype.
        action_dim_is_pad: Optional padding flags; when
            ``config.mask_action_dim_padding`` is set they are used to zero the
            padded entries of both the actions and the noise.
        timesteps: Optional timesteps of shape ``(batch, num_flow_timesteps)``.
            Sampled with ``_sample_beta_timesteps`` when omitted.
        noise: Optional noise of shape
            ``(batch, num_flow_timesteps, actions.shape[1], actions.shape[2])``.
            Drawn with ``torch.randn`` when omitted.

    Returns:
        ``(actions, timesteps, xt, target_velocity)`` where ``actions`` is the
        cast (and possibly masked) input, ``timesteps`` has shape
        ``(batch, num_flow_timesteps)``, and ``xt`` / ``target_velocity`` share
        the noise shape.

    Raises:
        ValueError: If a caller-supplied ``timesteps`` or ``noise`` has the wrong shape.
    """
    action_expert = self._backbone()._require_action_expert()
    # All flow tensors follow the dtype of the action expert's first parameter.
    action_dtype = next(action_expert.parameters()).dtype
    actions = actions.to(dtype=action_dtype)
    batch_size = int(actions.shape[0])
    device = actions.device
    num_flow_timesteps = max(1, int(self.config.num_flow_timesteps))

    if timesteps is None:
        # One flat draw of batch * T samples, reshaped to (batch, T).
        timesteps = (
            _sample_beta_timesteps(
                batch_size=batch_size * num_flow_timesteps,
                device=device,
                cutoff=self.config.flow_matching_cutoff,
                time_offset=self.config.flow_matching_time_offset,
                time_scale=self.config.flow_matching_time_scale,
                alpha=self.config.flow_matching_beta_alpha,
                beta=self.config.flow_matching_beta_beta,
            )
            .to(dtype=action_dtype)
            .view(batch_size, num_flow_timesteps)
        )
    else:
        expected_timesteps_shape = (batch_size, num_flow_timesteps)
        timesteps = timesteps.to(device=device, dtype=action_dtype)
        if tuple(timesteps.shape) != expected_timesteps_shape:
            raise ValueError(
                f"flow timesteps must have shape {expected_timesteps_shape}, got {tuple(timesteps.shape)}."
            )

    if self.config.mask_action_dim_padding:
        actions = self._mask_action_dim_tensor(actions, action_dim_is_pad)

    expected_noise_shape = (batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2])
    if noise is None:
        noise = torch.randn(*expected_noise_shape, device=device, dtype=actions.dtype)
    else:
        noise = noise.to(device=device, dtype=actions.dtype)
        if tuple(noise.shape) != expected_noise_shape:
            raise ValueError(
                f"flow noise must have shape {expected_noise_shape}, got {tuple(noise.shape)}."
            )
    if self.config.mask_action_dim_padding:
        # Zero the noise on padded dims as well, so xt and the target are both zero there.
        noise = self._mask_action_dim_tensor(noise, action_dim_is_pad)

    # Broadcast t over the two trailing action axes and actions over the timestep axis.
    t_broadcast = timesteps.view(batch_size, num_flow_timesteps, 1, 1)
    actions_expanded = actions.unsqueeze(1).expand(-1, num_flow_timesteps, -1, -1)
    xt = (1.0 - t_broadcast) * noise + t_broadcast * actions_expanded
    target_velocity = actions_expanded - noise
    return actions, timesteps, xt, target_velocity
def _compute_flow_matching_loss_joint_per_layer(
    self,
    *,
    batch: dict[str, Tensor],
    model_inputs: dict[str, Tensor],
    timesteps: Tensor | None = None,
    noise: Tensor | None = None,
    reduction: str = "mean",
) -> tuple[Tensor, Tensor]:
    """Run the text transformer and the action expert in lock-step and return the flow loss.

    For each layer index, the decoder block is executed once on the context and
    its key/value states are projected and fed as cross-attention context to
    the action-expert block with the same index. The action expert processes
    ``batch * num_flow_timesteps`` noisy action sequences, so the per-layer
    context and all masks are repeated ``num_flow_timesteps`` times.

    Args:
        batch: Training batch. Reads ``ACTION`` and, when present,
            ``"action_dim_is_pad"`` and ``"action_horizon_is_pad"``.
        model_inputs: Backbone inputs as accepted by
            ``_prepare_joint_training_backbone_inputs``.
        timesteps: Optional flow timesteps forwarded to
            ``_prepare_flow_matching_tensors``.
        noise: Optional flow noise forwarded to ``_prepare_flow_matching_tensors``.
        reduction: ``"mean"`` for a scalar loss, ``"none"`` for one value per example.

    Returns:
        ``(loss, hidden_states)`` where ``hidden_states`` is the backbone
        output after ``transformer.ln_f``.

    Raises:
        ValueError: On an unsupported ``reduction`` or a batch-size mismatch
            between backbone inputs and actions.
        RuntimeError: If the backbone has no ``transformer`` or the number of
            action-expert blocks differs from the number of transformer layers.
    """
    if reduction not in {"mean", "none"}:
        raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.")
    backbone = self._backbone()
    transformer = getattr(backbone, "transformer", None)
    action_expert = backbone._require_action_expert()
    if transformer is None:
        raise RuntimeError("MolmoAct2 joint flow training requires a patchable text transformer.")
    if len(action_expert.blocks) != int(transformer.config.num_hidden_layers):
        raise RuntimeError(
            "MolmoAct2 joint flow training requires one action expert block per text transformer layer."
        )

    actions, timesteps, xt, target_velocity = self._prepare_flow_matching_tensors(
        actions=batch[ACTION],
        action_dim_is_pad=batch.get("action_dim_is_pad"),
        timesteps=timesteps,
        noise=noise,
    )
    num_flow_timesteps = max(1, int(self.config.num_flow_timesteps))
    batch_size = int(actions.shape[0])
    device = actions.device
    # Fold the timestep axis into the batch axis for the action expert.
    xt_flat = xt.reshape(batch_size * num_flow_timesteps, actions.shape[1], actions.shape[2])
    timesteps_flat = timesteps.reshape(batch_size * num_flow_timesteps)

    hidden_states, causal_mask_mapping, position_ids, cache_position = (
        self._prepare_joint_training_backbone_inputs(model_inputs)
    )
    if hidden_states.shape[0] != batch_size:
        raise ValueError(
            f"Backbone batch size {hidden_states.shape[0]} does not match action batch size {batch_size}."
        )

    encoder_attention_mask = self._encoder_attention_mask_for_action_expert(
        input_ids=model_inputs.get("input_ids"),
        attention_mask=model_inputs.get("attention_mask"),
    )
    # True marks a real (non-padded) step of the action horizon.
    action_attention_mask = None
    if batch.get("action_horizon_is_pad") is not None:
        action_attention_mask = ~batch["action_horizon_is_pad"].to(device=device, dtype=torch.bool)

    # Multiplicative version of the horizon mask, repeated once per flow timestep.
    valid_action = None
    if action_attention_mask is not None:
        valid_action = action_attention_mask.to(device=device, dtype=actions.dtype).unsqueeze(-1)
        valid_action = self._expand_mask(valid_action, num_flow_timesteps)

    # The cache built from the first block's rope is passed to every action block below.
    rope_cache = None
    if len(action_expert.blocks) > 0 and action_expert.blocks[0].self_attn.rope is not None:
        rope_cache = action_expert.blocks[0].self_attn.rope.build_cache(
            seq_len=actions.shape[1],
            device=device,
            dtype=actions.dtype,
        )

    cross_mask = action_expert._build_cross_attention_mask(
        encoder_attention_mask,
        batch_size,
        actions.dtype,
    )
    cross_mask = self._expand_mask(cross_mask, num_flow_timesteps)
    self_mask = action_expert._build_self_attention_mask(
        action_attention_mask,
        actions.shape[1],
        device,
        actions.dtype,
    )
    self_mask = self._expand_mask(self_mask, num_flow_timesteps)

    conditioning = self._action_time_conditioning(action_expert, timesteps_flat)
    action_hidden = action_expert.action_embed(xt_flat)
    if valid_action is not None:
        action_hidden = action_hidden * valid_action

    # Rotary embeddings are computed once; with rope_scaling_layers set, two variants are kept.
    if transformer.config.rope_scaling_layers is not None:
        position_embeddings_mapping = {
            "default": transformer.rotary_embs["default"](hidden_states, position_ids),
            "scaling": transformer.rotary_embs["scaling"](hidden_states, position_ids),
        }
    else:
        position_embeddings = transformer.rotary_emb(hidden_states, position_ids)

    use_gradient_checkpointing = bool(
        getattr(self.config, "gradient_checkpointing", False)
        and self.training
        and torch.is_grad_enabled()
    )

    def run_layer(
        layer_idx: int, layer_hidden: Tensor, layer_action_hidden: Tensor
    ) -> tuple[Tensor, Tensor]:
        # One decoder block followed by the action-expert block of the same index.
        decoder_block = transformer.blocks[layer_idx]
        action_block = action_expert.blocks[layer_idx]
        if transformer.config.rope_scaling_layers is not None:
            position_embeddings_i = (
                position_embeddings_mapping["scaling"]
                if layer_idx in transformer.config.rope_scaling_layers
                else position_embeddings_mapping["default"]
            )
        else:
            position_embeddings_i = position_embeddings

        layer_outputs = decoder_block(
            layer_hidden,
            position_embeddings=position_embeddings_i,
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=None,
            output_attentions=False,
            use_cache=False,
            cache_position=cache_position,
            collect_layer_kv_states=True,
        )
        next_hidden = layer_outputs[0]
        key_states, value_states = self._decoder_layer_kv_outputs(layer_outputs, output_attentions=False)
        key_states = backbone._cache_to_sequence(key_states)
        value_states = backbone._cache_to_sequence(value_states)
        if self.config.enable_knowledge_insulation:
            # Detaching stops the flow loss from back-propagating into the backbone through K/V.
            key_states = key_states.detach()
            value_states = value_states.detach()

        k_ctx = action_expert._project_kv_tensor(key_states, action_expert.context_k_proj)
        v_ctx = action_expert._project_kv_tensor(value_states, action_expert.context_v_proj)
        k_norm = action_block.cross_attn.k_norm
        if k_norm is not None:
            k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
        if num_flow_timesteps != 1:
            # _expand_mask is reused here purely as a repeat-along-batch helper.
            k_ctx = self._expand_mask(k_ctx, num_flow_timesteps)
            v_ctx = self._expand_mask(v_ctx, num_flow_timesteps)

        next_action_hidden = action_block(
            layer_action_hidden,
            conditioning,
            cross_kv=(k_ctx, v_ctx),
            self_attn_mask=self_mask,
            attn_mask=cross_mask,
            is_causal=action_expert.config.causal_attn,
            modulation=None,
            rope_cache=rope_cache,
        )
        if valid_action is not None:
            next_action_hidden = next_action_hidden * valid_action
        return next_hidden, next_action_hidden

    for layer_idx in range(int(transformer.config.num_hidden_layers)):
        if use_gradient_checkpointing:
            # idx is bound as a default argument so each lambda keeps its own layer index.
            hidden_states, action_hidden = torch.utils.checkpoint.checkpoint(
                lambda layer_hidden, layer_action_hidden, idx=layer_idx: run_layer(
                    idx,
                    layer_hidden,
                    layer_action_hidden,
                ),
                hidden_states,
                action_hidden,
                use_reentrant=False,
            )
        else:
            hidden_states, action_hidden = run_layer(layer_idx, hidden_states, action_hidden)

    hidden_states = transformer.ln_f(hidden_states)
    pred_velocity = action_expert.final_layer(action_hidden, conditioning)
    if valid_action is not None:
        pred_velocity = pred_velocity * valid_action
    pred_velocity = pred_velocity.reshape(
        batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2]
    )

    loss = F.mse_loss(pred_velocity, target_velocity, reduction="none")
    loss = self._apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
    if self.config.mask_action_dim_padding:
        loss = self._apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
    # NOTE(review): padded horizon steps were zeroed above but still count in this mean's
    # denominator — confirm that dilution is intended.
    loss = loss.reshape(batch_size, -1).mean(dim=1)
    if reduction == "mean":
        loss = loss.mean()
    return loss, hidden_states
weights) / weights.sum().clamp_min(1.0) + + @staticmethod + def _weighted_per_example( + values: Tensor, + weights: Tensor | None, + example_indices: Tensor, + batch_size: int, + ) -> Tensor: + values = values.float() + if weights is None: + weights = torch.ones_like(values) + else: + weights = weights.to(device=values.device, dtype=values.dtype) + loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) + weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) + loss_sum.scatter_add_(0, example_indices, values * weights) + weight_sum.scatter_add_(0, example_indices, weights) + global_weight_sum = weight_sum.sum().clamp_min(1.0) + return loss_sum * float(batch_size) / global_weight_sum + + def _discrete_loss_from_backbone_outputs( + self, + batch: dict[str, Tensor], + outputs: Any, + reduction: str = "mean", + ) -> tuple[Tensor, Tensor | None]: + if reduction not in {"mean", "none"}: + raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.") + labels = batch.get("labels") + if labels is None: + raise RuntimeError("MolmoAct2 discrete training requires labels.") + hidden_states = outputs.last_hidden_state + if hidden_states is None: + raise RuntimeError("MolmoAct2 backbone did not return last_hidden_state.") + + ignore_index = -100 + shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].contiguous() + valid_positions = shift_labels != ignore_index + if not bool(valid_positions.any()): + raise RuntimeError("MolmoAct2 discrete training labels contain no valid action tokens.") + + hidden_size = hidden_states.shape[-1] + selected_hidden = hidden_states.reshape(-1, hidden_size)[valid_positions.reshape(-1)] + selected_labels = shift_labels.reshape(-1)[valid_positions.reshape(-1)].to( + device=hidden_states.device + ) + logits = F.linear(selected_hidden, self.model.lm_head.weight).float() + log_z = logits.logsumexp(dim=-1) + target_logits = logits.gather(dim=-1, index=selected_labels[:, 
None]).squeeze(-1) + token_ce_loss = log_z - target_logits + token_weights = self._discrete_token_weights(valid_positions) + if reduction == "none": + example_indices = valid_positions.nonzero(as_tuple=False)[:, 0].to(device=hidden_states.device) + ce_loss = self._weighted_per_example( + token_ce_loss, + token_weights, + example_indices, + int(labels.shape[0]), + ) + else: + ce_loss = self._weighted_mean(token_ce_loss, token_weights) + if not self.config.softmax_auxiliary_loss: + return ce_loss, None + + if reduction == "none": + z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_per_example( + log_z.pow(2), + token_weights, + example_indices, + int(labels.shape[0]), + ) + else: + z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_mean( + log_z.pow(2), token_weights + ) + return ce_loss, z_loss + + @staticmethod + def _extract_discrete_token_bins( + generated_ids: list[int], + start_token_id: int, + end_token_id: int, + token_id_to_bin: dict[int, int], + ) -> list[int]: + start_idx = None + end_idx = None + for idx, token_id in enumerate(generated_ids): + if token_id == start_token_id: + start_idx = idx + break + if start_idx is not None: + for idx in range(start_idx + 1, len(generated_ids)): + if generated_ids[idx] == end_token_id: + end_idx = idx + break + span_start = 0 if start_idx is None else start_idx + 1 + span_end = len(generated_ids) if end_idx is None else end_idx + return [ + int(token_id_to_bin[token_id]) + for token_id in generated_ids[span_start:span_end] + if token_id in token_id_to_bin + ] + + def _action_token_id_to_bin(self) -> dict[int, int]: + method = getattr(self.model, "_action_token_id_to_bin", None) + if callable(method): + return dict(method()) + start = getattr(self.model.config, "action_token_start_id", None) + num_tokens = int(getattr(self.model.config, "num_action_tokens", 0) or 0) + if start is None or num_tokens <= 0: + return {} + return {int(start) + idx: idx for idx in range(num_tokens)} + + def 
_require_discrete_eos_token_id(self) -> int: + method = getattr(self.model, "_require_eos_token_id", None) + if callable(method): + return int(method()) + eos_token_id = getattr(self.model.config, "eos_token_id", None) + if eos_token_id is None and getattr(self.model, "generation_config", None) is not None: + eos_token_id = getattr(self.model.generation_config, "eos_token_id", None) + if isinstance(eos_token_id, (list, tuple)): + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None: + raise RuntimeError("Discrete action generation requires eos_token_id in the checkpoint config.") + return int(eos_token_id) + + def _discrete_generation_max_steps(self) -> int: + if self.config.discrete_generation_max_steps is not None: + return int(self.config.discrete_generation_max_steps) + return max(1, self._generation_action_horizon() * 16) + + def _continue_discrete_generation_from_output( + self, + initial_output: Any, + *, + past_key_values: Any | None, + attention_mask: Tensor | None, + end_token_id: int, + max_steps: int, + attention_bias: Tensor | None = None, + ) -> Tensor: + consume_generation_tokens = getattr(self.model, "_consume_generation_tokens", None) + ar_decode_step = getattr(self.model, "_run_ar_decode_step", None) + if ar_decode_step is None: + ar_decode_step = getattr(self.model, "_run_depth_decode_step", None) + if attention_bias is None and not callable(consume_generation_tokens): + raise RuntimeError("MolmoAct2 checkpoint does not expose discrete token generation helpers.") + if attention_bias is not None and not callable(ar_decode_step): + raise RuntimeError("MolmoAct2 checkpoint does not expose graph-backed AR decode helpers.") + + generated_tokens: list[Tensor] = [] + current_output = initial_output + current_past_key_values = past_key_values + current_attention_mask = attention_mask + hit_end = False + for _ in range(int(max_steps)): + next_token = torch.argmax(current_output.logits[:, -1, :], dim=-1) + 
generated_tokens.append(next_token) + if bool((next_token == int(end_token_id)).all()): + hit_end = True + break + if attention_bias is None: + current_output, current_attention_mask = consume_generation_tokens( + next_token, + past_key_values=current_past_key_values, + attention_mask=current_attention_mask, + ) + current_past_key_values = current_output.past_key_values + else: + last_hidden, current_past_key_values = ar_decode_step( + next_token, + past_key_values=current_past_key_values, + attention_bias=attention_bias, + ) + current_output = types.SimpleNamespace( + logits=self.model.lm_head(last_hidden), + past_key_values=current_past_key_values, + ) + if not generated_tokens: + raise RuntimeError("Discrete continuation generated no tokens.") + if not hit_end: + raise RuntimeError( + f"Discrete continuation did not emit end token {int(end_token_id)} within {int(max_steps)} steps." + ) + return torch.stack(generated_tokens, dim=1) + + def _make_discrete_ar_graph_decode_inputs( + self, + model_inputs: dict[str, Tensor], + *, + max_steps: int, + ) -> tuple[Any | None, Tensor | None]: + if not bool(getattr(self.config, "enable_inference_cuda_graph", False)): + return None, None + if self.training or self.model.training: + return None, None + ar_decode_step = getattr(self.model, "_run_ar_decode_step", None) + if ar_decode_step is None: + ar_decode_step = getattr(self.model, "_run_depth_decode_step", None) + make_attention_bias = getattr(self.model, "_make_depth_decode_attention_bias", None) + if not callable(ar_decode_step) or not callable(make_attention_bias): + return None, None + + make_static_cache = getattr(self.model, "_make_ar_decode_static_cache", None) + if callable(make_static_cache): + static_cache = make_static_cache(model_inputs, max_steps=max_steps) + else: + graph_manager = getattr(self.model, "depth_decode_cuda_graph_manager", None) + make_manager_static_cache = getattr(graph_manager, "make_static_cache", None) + if not 
def _decode_discrete_action_chunk(self, generated_token_ids: Tensor, *, action_dim: int) -> Tensor:
    """Decode generated token ids into float32 action chunks, one per batch row.

    Each row is reduced to its action-bin ids with ``_extract_discrete_token_bins``
    (using the config's action start/end token ids), decoded by the tokenizer
    returned from ``_load_discrete_action_tokenizer()``, coerced to a 2-D array and
    stacked along a new leading dimension.

    Args:
        generated_token_ids: Token ids with 1, 2 or 3 dimensions. A 1-D input is
            treated as a single row; for a 3-D input only index 0 of the second
            dimension is used.
        action_dim: Forwarded to the tokenizer's ``decode`` as ``action_dim``.

    Returns:
        The per-row decoded chunks stacked with ``torch.stack(dim=0)``, float32,
        on the device of the token rows.

    Raises:
        RuntimeError: If the config lacks the action start/end token ids, if no
            action-token mapping is available, if a row contains no decodable
            action tokens, or if a decoded chunk is not 2-D after coercion.
        ValueError: If ``generated_token_ids`` is not 1-, 2- or 3-dimensional.
    """
    if (
        getattr(self.model.config, "action_start_token_id", None) is None
        or getattr(self.model.config, "action_end_token_id", None) is None
    ):
        # NOTE(review): this message reads as if token names were dropped around the "/"
        # in this copy of the file - confirm against upstream before touching it.
        raise RuntimeError("Discrete action generation requires / token IDs.")
    token_id_to_bin = self._action_token_id_to_bin()
    if not token_id_to_bin:
        raise RuntimeError(
            "Discrete action generation requires indexed action tokens in the checkpoint config."
        )

    action_tokenizer = self._load_discrete_action_tokenizer()
    # Normalise the input to (rows, tokens).
    if generated_token_ids.ndim == 1:
        generated_token_ids = generated_token_ids.unsqueeze(0)
    if generated_token_ids.ndim == 3:
        generated_token_ids = generated_token_ids[:, 0, :]
    if generated_token_ids.ndim != 2:
        raise ValueError(f"Unexpected generated token tensor shape {tuple(generated_token_ids.shape)}.")

    chunks: list[Tensor] = []
    for token_row in generated_token_ids:
        generated_ids = [int(token_id) for token_id in token_row.detach().cpu().tolist()]
        discrete_token_ids = self._extract_discrete_token_bins(
            generated_ids,
            int(self.model.config.action_start_token_id),
            int(self.model.config.action_end_token_id),
            token_id_to_bin,
        )
        if not discrete_token_ids:
            # NOTE(review): same apparent loss of token names around "/" as above - confirm.
            raise RuntimeError(
                "Model generated no decodable action tokens between /."
            )
        try:
            decoded = action_tokenizer.decode(
                [discrete_token_ids],
                time_horizon=self._generation_action_horizon(),
                action_dim=int(action_dim),
            )
        except TypeError:
            # Presumably a fallback for tokenizers whose decode() does not accept the
            # keyword arguments. NOTE(review): this also retries on a TypeError raised
            # from inside decode() for any other reason.
            decoded = action_tokenizer.decode([discrete_token_ids])
        action_chunk = np.asarray(decoded, dtype=np.float32)
        # Coerce the decoded array to 2-D.
        if action_chunk.ndim == 1:
            action_chunk = action_chunk[None, :]
        elif action_chunk.ndim == 3:
            if int(action_chunk.shape[0]) != 1:
                # NOTE(review): reshaping (k, a, b) to (a, b) keeps the element count only
                # when k == 1, so for k > 1 this raises numpy's reshape ValueError rather
                # than the RuntimeError below.
                action_chunk = action_chunk.reshape(action_chunk.shape[-2], action_chunk.shape[-1])
            else:
                action_chunk = action_chunk[0]
        elif action_chunk.ndim > 3:
            # Succeeds only when all leading dimensions have size 1 (same reasoning as above).
            action_chunk = action_chunk.reshape(action_chunk.shape[-2], action_chunk.shape[-1])
        if action_chunk.ndim != 2:
            # Reachable for a 0-d decode result; every other rank was handled above.
            raise RuntimeError(f"Decoded action chunk has unexpected shape {action_chunk.shape}.")
        chunks.append(torch.as_tensor(action_chunk, device=token_row.device, dtype=torch.float32))
    # torch.stack requires every row to have decoded to the same 2-D shape.
    return torch.stack(chunks, dim=0)
self, + *, + model_inputs: dict[str, Tensor], + action_dim_is_pad: Tensor | None, + num_steps: int | None, + generator: torch.Generator | None, + inference_delay: int | None, + prev_chunk_left_over: Tensor | None, + execution_horizon: int | None, + ) -> Tensor: + backbone = self._backbone() + action_expert = self._action_expert() + outputs = backbone( + **model_inputs, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + ) + encoder_kv_states = backbone._extract_kv_states(outputs.past_key_values) + encoder_attention_mask = self._encoder_attention_mask_for_action_expert( + input_ids=model_inputs.get("input_ids"), + attention_mask=model_inputs.get("attention_mask"), + ) + depth_gate, depth_mask = backbone._depth_gate_from_condition( + input_ids=model_inputs.get("input_ids"), + encoder_attention_mask=encoder_attention_mask, + layer_kv_states=encoder_kv_states, + ) + encoder_kv_states = backbone._apply_depth_gate_to_layer_kv_states( + encoder_kv_states, + depth_mask, + depth_gate, + ) + + steps = int(num_steps or backbone.config.flow_matching_num_steps) + if steps <= 0: + raise ValueError(f"num_steps must be >= 1, got {steps}.") + source_tensor = encoder_kv_states[0][0] + batch_size = int(source_tensor.shape[0]) + device = source_tensor.device + trajectory = torch.randn( + batch_size, + self._generation_action_horizon(), + int(backbone.config.max_action_dim), + device=device, + dtype=torch.float32, + generator=generator, + ) + if self.config.mask_action_dim_padding: + trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad) + + action_context = action_expert.prepare_context( + encoder_kv_states=encoder_kv_states, + encoder_attention_mask=encoder_attention_mask, + state_embeddings=None, + batch_size=batch_size, + seq_len=trajectory.shape[1], + device=device, + dtype=trajectory.dtype, + ) + flow_timesteps = [ + torch.full((batch_size,), idx / steps, device=device, dtype=trajectory.dtype) + for idx in range(steps) + ] + 
def forward(
    self,
    batch: dict[str, Tensor],
    reduction: str = "mean",
) -> tuple[Tensor, dict[str, Any]]:
    """Compute the training loss for the configured action mode.

    ``config.action_mode`` selects the loss terms:

    * ``"discrete"``: runs the backbone and uses the discrete token loss
      (cross-entropy plus the optional z-loss).
    * ``"continuous"``: uses the flow-matching loss only.
    * anything else: computes the flow-matching loss and reuses the hidden states
      it returns for the discrete loss; both terms are summed.

    Args:
        batch: Training batch; converted to model inputs by ``_model_inputs``.
        reduction: ``"mean"`` or ``"none"``; forwarded to the loss helpers.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the sum of the active terms and
        ``metrics`` holds Python floats for each term and for ``"loss"``.

    Raises:
        ValueError: If ``reduction`` is neither ``"mean"`` nor ``"none"``.
    """
    if reduction not in {"mean", "none"}:
        raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.")
    model_inputs = self._model_inputs(batch)
    loss_terms: list[Tensor] = []
    metrics: dict[str, Any] = {}

    def _as_metric(value: Tensor) -> float:
        return value.detach().float().mean().item()

    def _add_discrete_term(backbone_outputs: Any) -> None:
        ce_loss, z_loss = self._discrete_loss_from_backbone_outputs(
            batch, backbone_outputs, reduction=reduction
        )
        loss_terms.append(ce_loss if z_loss is None else ce_loss + z_loss)
        metrics["discrete_ce_loss"] = _as_metric(ce_loss)
        if z_loss is not None:
            metrics["discrete_z_loss"] = _as_metric(z_loss)

    def _flow_matching():
        return self._compute_flow_matching_loss_joint_per_layer(
            batch=batch,
            model_inputs=model_inputs,
            reduction=reduction,
        )

    def _add_flow_term(flow_value: Tensor) -> None:
        loss_terms.append(flow_value)
        metrics["action_flow_loss"] = _as_metric(flow_value)

    action_mode = self.config.action_mode
    if action_mode == "discrete":
        backbone_outputs = self._backbone()(
            **model_inputs,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
        )
        _add_discrete_term(backbone_outputs)
    elif action_mode == "continuous":
        flow_loss, _ = _flow_matching()
        _add_flow_term(flow_loss)
    else:
        # Joint mode: the flow pass also yields the hidden states the token loss needs.
        flow_loss, hidden_states = _flow_matching()
        _add_discrete_term(types.SimpleNamespace(last_hidden_state=hidden_states))
        _add_flow_term(flow_loss)

    total = torch.stack(loss_terms).sum(dim=0)
    metrics["loss"] = _as_metric(total)
    return total, metrics
predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + if "action_mode" in kwargs: + raise TypeError( + "MolmoAct2 predict_action_chunk got unexpected keyword argument 'action_mode'; " + "use 'inference_action_mode'." + ) + model_inputs = self._model_inputs(batch) + inference_action_mode = self._resolve_inference_action_mode(kwargs.get("inference_action_mode")) + num_steps = kwargs.get("num_steps", getattr(self.config, "num_inference_steps", None)) + generator = kwargs.get("generator") + model_dtype = _torch_dtype(self.config.model_dtype) + device = next(self.parameters()).device + batch_size = int(next(iter(model_inputs.values())).shape[0]) + if generator is None: + generator = self._rollout_generator_for_inputs( + batch, + batch_size=batch_size, + device=device, + ) + action_dim = self._output_action_dim(batch) + autocast_context = ( + torch.autocast(device_type=device.type, dtype=model_dtype) + if device.type in {"cuda", "cpu"} and model_dtype in {torch.bfloat16, torch.float16} + else nullcontext() + ) + with autocast_context: + if inference_action_mode == "discrete": + if self._rtc_enabled(): + raise ValueError("RTC is only supported for continuous MolmoAct2 inference.") + actions = self._generate_discrete_actions_from_inputs( + model_inputs=model_inputs, + action_dim=action_dim, + ) + elif self._rtc_enabled(): + actions = self._generate_actions_from_inputs_with_rtc( + model_inputs=model_inputs, + action_dim_is_pad=batch.get("action_dim_is_pad"), + num_steps=num_steps, + generator=generator, + inference_delay=kwargs.get("inference_delay"), + prev_chunk_left_over=kwargs.get("prev_chunk_left_over"), + execution_horizon=kwargs.get("execution_horizon"), + ) + else: + actions = self._backbone().generate_actions_from_inputs( + **model_inputs, + action_dim_is_pad=batch.get("action_dim_is_pad"), + action_horizon=self._generation_action_horizon(), + num_steps=num_steps, + generator=generator, + ) + return actions[:, : self.config.n_action_steps, 
:action_dim].to(dtype=torch.float32) + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + if self._rtc_enabled(): + raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk") + self.eval() + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps] + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def _get_default_peft_targets(self) -> dict[str, Any]: + target_modules = self._lora_target_modules(prefix=r"model\.model") + return { + "target_modules": target_modules, + "modules_to_save": [], + "r": self.config.lora_rank, + "lora_alpha": self.config.lora_alpha, + "lora_dropout": self.config.lora_dropout, + "bias": self.config.lora_bias, + } + + def _get_inner_peft_targets(self) -> dict[str, Any]: + target_modules = self._lora_target_modules(prefix="model") + return { + "target_modules": target_modules, + "modules_to_save": [], + "r": self.config.lora_rank, + "lora_alpha": self.config.lora_alpha, + "lora_dropout": self.config.lora_dropout, + "bias": self.config.lora_bias, + } + + def _lora_target_modules(self, *, prefix: str) -> str: + vlm_linear_leaves = "w1|w2|w3|wq|wk|wv|wo|att_proj|attn_out|ff_proj|ff_out|patch_embedding" + target_modules = rf"{prefix}\.(transformer|vision_backbone)\.(?:.*\.)?({vlm_linear_leaves})$" + if self.config.enable_lora_action_expert: + action_expert_linear_paths = ( + r"time_embed\.(1|3)|" + r"action_embed|context_k_proj|context_v_proj|" + r"blocks\.\d+\.self_attn\.(qkv|out_proj)|" + r"blocks\.\d+\.cross_attn\.(q_proj|out_proj)|" + r"blocks\.\d+\.mlp\.(up_proj|gate_proj|down_proj)|" + r"blocks\.\d+\.modulation\.linear|" + r"final_layer\.(modulation\.linear|linear)" + ) + target_modules = ( + f"({target_modules}|" + rf"{prefix}\.action_expert\.({action_expert_linear_paths})$)" + ) + return target_modules + + def _build_inner_lora_config(self): + 
def _apply_lora_adapters(self) -> None:
    """Wrap ``self.model`` with LoRA adapters built from the policy config.

    Steps, in order: require the ``peft`` package, build and validate the LoRA
    config, freeze every parameter of ``self.model``, replace ``self.model`` with
    the result of ``get_peft_model``, optionally unfreeze the action expert, and
    re-apply the current train/eval mode.
    """
    require_package("peft", extra="molmoact2")
    # Imported lazily, after the package check above.
    from peft import get_peft_model

    peft_config = self._build_inner_lora_config()
    self._validate_peft_config(peft_config)

    # Freeze the whole base model before the adapters are attached.
    for param in self.model.parameters():
        param.requires_grad_(False)
    self.model = get_peft_model(self.model, peft_config)
    if not self.config.enable_lora_action_expert:
        # The action expert gets no adapters in this configuration, so it is unfrozen
        # via the helper instead (presumably trained with full weights - confirm in
        # _unfreeze_action_expert_parameters).
        self._unfreeze_action_expert_parameters()
    # Propagate the policy's current training flag to the newly wrapped model.
    self.train(self.training)


def _validate_peft_config(self, peft_config) -> None:
    """Check the preconditions for LoRA fine-tuning.

    The ``peft_config`` argument is accepted but not inspected; the only check is
    that ``config.checkpoint_path`` is set.

    Raises:
        ValueError: If ``config.checkpoint_path`` is empty or unset.
    """
    del peft_config  # intentionally unused
    if not self.config.checkpoint_path:
        raise ValueError("MolmoAct2 LoRA fine-tuning requires `policy.checkpoint_path`.")
+ +from __future__ import annotations + +import json +import os +import re +from contextlib import suppress +from copy import deepcopy +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from huggingface_hub import snapshot_download +from torch import Tensor + +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.types import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) +from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package + +from .configuration_molmoact2 import MolmoAct2Config, infer_molmoact2_max_sequence_length + +if TYPE_CHECKING or _transformers_available: + from transformers import Qwen2Tokenizer + + from .hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor + from .hf_model.processing_molmoact2 import MolmoAct2Processor + from .hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor +else: + Qwen2Tokenizer = None + MolmoAct2ImageProcessor = None + MolmoAct2Processor = None + MolmoAct2VideoProcessor = None + +if TYPE_CHECKING or (_transformers_available and _scipy_available): + from .hf_model.action_tokenizer import UniversalActionProcessor +else: + UniversalActionProcessor = None + +ACTION_OUTPUT_TOKEN = "" # nosec B105 +ACTION_START_TOKEN = "" # nosec B105 +ACTION_END_TOKEN = "" # nosec B105 +ACTION_TOKEN_PREFIX = " str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + 
+def _resolve_checkpoint_location( + checkpoint_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + checkpoint_path = str(checkpoint_path or "").strip() + if not checkpoint_path: + raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") + local_path = Path(checkpoint_path).expanduser() + if local_path.exists(): + return str(local_path) + return snapshot_download( + repo_id=checkpoint_path, + repo_type="model", + revision=revision, + force_download=force_download, + ignore_patterns=["*.py", "*.pyc", "__pycache__/*"], + token=_hf_token(), + ) + + +def _load_hf_norm_stats_for_tag( + checkpoint_path: str, + *, + revision: str | None, + force_download: bool, + norm_tag: str | None, +) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]: + norm_tag = str(norm_tag or "").strip() + if not norm_tag: + raise ValueError("MolmoAct2 HF checkpoint inference requires `policy.norm_tag` for normalization.") + + checkpoint_location = Path( + _resolve_checkpoint_location( + checkpoint_path, + revision=revision, + force_download=force_download, + ) + ) + config_path = checkpoint_location / "config.json" + norm_stats_filename = "norm_stats.json" + if config_path.exists(): + with suppress(OSError, json.JSONDecodeError): + norm_stats_filename = str( + json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename + ) + + stats_path = checkpoint_location / norm_stats_filename + if not stats_path.exists(): + raise FileNotFoundError( + f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." 
+ ) + payload = json.loads(stats_path.read_text()) + metadata_by_tag = payload.get("metadata_by_tag") + if not isinstance(metadata_by_tag, dict): + raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") + metadata = metadata_by_tag.get(norm_tag) + if metadata is None: + available = sorted(str(tag) for tag in metadata_by_tag) + raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") + if not isinstance(metadata, dict): + raise ValueError(f"MolmoAct2 norm_tag={norm_tag!r} metadata must be a mapping.") + + def numeric_stats(raw_stats: dict[str, Any]) -> dict[str, Any]: + stats: dict[str, Any] = {} + for key, value in raw_stats.items(): + if key == "names": + continue + if isinstance(value, (list, tuple)) and any(isinstance(item, str) for item in value): + continue + stats[key] = deepcopy(value) + return stats + + action_stats = metadata.get("action_stats") + state_stats = metadata.get("state_stats") + if not isinstance(action_stats, dict) or not isinstance(state_stats, dict): + raise ValueError(f"MolmoAct2 norm_tag={norm_tag!r} must define action_stats and state_stats.") + return {ACTION: numeric_stats(action_stats), OBS_STATE: numeric_stats(state_stats)}, metadata + + +def _strip_processor_config(config: dict[str, Any], *metadata_keys: str) -> dict[str, Any]: + return { + key: value + for key, value in config.items() + if key not in {"auto_map", "processor_class", *metadata_keys} + } + + +def _load_local_molmoact2_processor(checkpoint_location: str) -> Any: + if ( + Qwen2Tokenizer is None + or MolmoAct2ImageProcessor is None + or MolmoAct2Processor is None + or MolmoAct2VideoProcessor is None + ): + raise RuntimeError("transformers is required to load MolmoAct2 processor.") + + checkpoint_path = Path(checkpoint_location) + processor_config_path = checkpoint_path / "processor_config.json" + if not processor_config_path.exists(): + raise FileNotFoundError(f"MolmoAct2 checkpoint is missing 
{processor_config_path}.") + processor_config = json.loads(processor_config_path.read_text()) + + image_config = _strip_processor_config( + dict(processor_config.get("image_processor") or {}), + "image_processor_type", + ) + video_config = _strip_processor_config( + dict(processor_config.get("video_processor") or {}), + "video_processor_type", + ) + image_processor = MolmoAct2ImageProcessor(**image_config) + video_processor = MolmoAct2VideoProcessor(**video_config) + tokenizer = Qwen2Tokenizer.from_pretrained( + checkpoint_location, + token=_hf_token(), + ) + + chat_template_path = checkpoint_path / "chat_template.jinja" + chat_template = chat_template_path.read_text() if chat_template_path.exists() else None + return MolmoAct2Processor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=tokenizer, + chat_template=chat_template, + image_use_col_tokens=processor_config.get("image_use_col_tokens", True), + use_single_crop_col_tokens=processor_config.get("use_single_crop_col_tokens"), + use_single_crop_start_token=processor_config.get("use_single_crop_start_token", True), + video_use_col_tokens=processor_config.get("video_use_col_tokens", False), + use_frame_special_tokens=processor_config.get("use_frame_special_tokens", True), + ) + + +def _to_numpy(value: Any) -> np.ndarray: + if isinstance(value, np.ndarray): + return value + if torch.is_tensor(value): + return value.detach().cpu().numpy() + return np.asarray(value) + + +def _normalize_image(value: Any) -> np.ndarray: + arr = _to_numpy(value) + while arr.ndim > 3 and int(arr.shape[0]) == 1: + arr = arr[0] + if arr.ndim == 2: + arr = np.stack([arr] * 3, axis=-1) + if arr.ndim == 3 and arr.shape[0] in {1, 3, 4} and arr.shape[-1] not in {1, 3, 4}: + arr = np.moveaxis(arr, 0, -1) + if arr.ndim == 3 and arr.shape[-1] == 1: + arr = np.repeat(arr, 3, axis=-1) + if arr.ndim != 3 or arr.shape[-1] not in {3, 4}: + raise ValueError(f"Unsupported image shape for MolmoAct2: {arr.shape}.") + if 
arr.shape[-1] == 4: + arr = arr[..., :3] + if arr.dtype in (np.float16, np.float32, np.float64): + if arr.size > 0 and float(np.nanmax(arr)) <= 1.0: + arr = arr * 255.0 + arr = np.clip(arr, 0, 255).astype(np.uint8) + elif arr.dtype != np.uint8: + arr = np.clip(arr, 0, 255).astype(np.uint8) + return arr + + +def _normalize_question_text(text: str) -> str: + normalized = re.sub(r"\s+", " ", str(text or "")).strip() + if not normalized: + return "" + previous = None + while normalized and normalized != previous: + previous = normalized + normalized = normalized.strip().strip(_QUESTION_SURROUNDING_DELIMITERS).strip() + for pattern in _QUESTION_PREFIX_PATTERNS: + normalized = pattern.sub("", normalized, count=1).strip() + normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() + normalized = normalized.rstrip(_QUESTION_TRAILING_CLOSERS).rstrip() + normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() + chunks = [chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip()] + if len(chunks) > 1: + normalized = "; ".join(chunks) + return normalized.lower() + + +def _wrap_setup_text(setup_type: str, add_setup_tokens: bool) -> str: + setup_type = str(setup_type or "") + if setup_type.startswith(SETUP_START_TOKEN) and setup_type.endswith(SETUP_END_TOKEN): + return setup_type + if not setup_type or not add_setup_tokens: + return setup_type + return f"{SETUP_START_TOKEN}{setup_type}{SETUP_END_TOKEN}" + + +def _wrap_control_text(control_mode: str, add_control_tokens: bool) -> str: + control_mode = str(control_mode or "") + if control_mode.startswith(CONTROL_START_TOKEN) and control_mode.endswith(CONTROL_END_TOKEN): + return control_mode + if not control_mode or not add_control_tokens: + return control_mode + return f"{CONTROL_START_TOKEN}{control_mode}{CONTROL_END_TOKEN}" + + +def _build_discrete_state_string(state: np.ndarray, num_state_tokens: int) -> str: + if num_state_tokens <= 0: + raise 
ValueError(f"num_state_tokens must be > 0, got {num_state_tokens}.") + arr = np.asarray(state, dtype=np.float32) + arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0) + arr = np.clip(arr, -1.0, 1.0) + scaled = (arr + 1.0) / 2.0 * float(num_state_tokens - 1) + token_ids = np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1).reshape(-1) + return f"{STATE_START_TOKEN}{''.join(f'{STATE_TOKEN_PREFIX}{int(token_id)}>' for token_id in token_ids)}{STATE_END_TOKEN}" + + +def _build_robot_text( + *, + task: str, + discrete_state_string: str, + setup_type: str, + control_mode: str, + add_setup_tokens: bool, + add_control_tokens: bool, + num_images: int, +) -> str: + setup_text = _wrap_setup_text(setup_type, add_setup_tokens=add_setup_tokens) + control_text = _wrap_control_text(control_mode, add_control_tokens=add_control_tokens) + state_clause = ( + f" The current state of the robot is {discrete_state_string}." if discrete_state_string else "" + ) + prompt = ( + f"The task is to {task}. The setup is {setup_text}.{state_clause} " + f"The expected control mode is {control_text}. Given these, what action should the robot take to complete the task?" 
+ ) + if num_images <= 0: + image_prefix = "" + elif num_images == 1: + image_prefix = "<|image|>" + else: + image_prefix = "".join(f"Image {idx + 1}<|image|>" for idx in range(num_images)) + return f"{image_prefix}<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{ACTION_OUTPUT_TOKEN}" + + +def _as_text_list(value: Any, batch_size: int) -> list[str]: + if value is None: + return [""] * batch_size + if isinstance(value, str): + return [value] * batch_size + if torch.is_tensor(value): + if value.ndim == 0: + return [str(value.item())] * batch_size + flat = value.detach().cpu().reshape(-1).tolist() + texts = [str(item) for item in flat] + elif isinstance(value, np.ndarray): + if value.ndim == 0: + return [str(value.item())] * batch_size + texts = [str(item) for item in value.reshape(-1).tolist()] + elif isinstance(value, (list, tuple)): + texts = [str(item) for item in value] + else: + texts = [str(value)] + if len(texts) == batch_size: + return texts + if len(texts) == 1: + return texts * batch_size + raise ValueError(f"Expected {batch_size} task strings, got {len(texts)}.") + + +def _tokenize_discrete_action(action: np.ndarray, processor: Any) -> list[int]: + arr = np.asarray(action, dtype=np.float32) + if arr.ndim == 2: + arr = arr[None, :, :] + elif arr.ndim == 1: + arr = arr[None, None, :] + tokens_out = processor(arr) + if isinstance(tokens_out, dict): + tokens_out = tokens_out.get("input_ids", next(iter(tokens_out.values()))) + if isinstance(tokens_out, np.ndarray): + tokens_out = tokens_out.tolist() + if torch.is_tensor(tokens_out): + tokens_out = tokens_out.detach().cpu().tolist() + if not isinstance(tokens_out, list): + raise TypeError(f"Unexpected discrete action tokenizer output type: {type(tokens_out)}") + if tokens_out and isinstance(tokens_out[0], (list, tuple, np.ndarray)): + tokens_out = tokens_out[0] + return [int(token_id) for token_id in tokens_out] + + +def _build_discrete_action_string(action: np.ndarray, processor: Any) -> str: + 
token_ids = _tokenize_discrete_action(action, processor) + pieces = "".join(f"{ACTION_TOKEN_PREFIX}{int(token_id)}>" for token_id in token_ids) + return f"{ACTION_START_TOKEN}{pieces}{ACTION_END_TOKEN}" + + +def _single_token_id(tokenizer: Any, token: str) -> int: + token_ids = tokenizer.encode(token, add_special_tokens=False) + if len(token_ids) != 1: + raise ValueError(f"MolmoAct2 token {token!r} must encode to one token, got {token_ids}.") + return int(token_ids[0]) + + +def _flatten_feature_names(raw_names: Any) -> list[str] | None: + if raw_names is None: + return None + if isinstance(raw_names, dict): + names: list[str] = [] + for value in raw_names.values(): + if isinstance(value, (list, tuple)): + names.extend(str(item) for item in value) + elif value is not None: + names.append(str(value)) + return names or None + if isinstance(raw_names, (list, tuple)): + names = [str(item) for item in raw_names] + return names or None + return [str(raw_names)] + + +def _feature_dim(stats: dict[str, Any] | None) -> int | None: + if not isinstance(stats, dict): + return None + for key in ("mean", "std", "min", "max", "q01", "q99", "q10", "q90", "mask"): + value = stats.get(key) + if value is None: + continue + if torch.is_tensor(value): + return int(value.shape[-1]) if value.ndim > 0 else None + arr = np.asarray(value) + return int(arr.shape[-1]) if arr.ndim > 0 else None + return None + + +def _stats_array(value: Any) -> np.ndarray | None: + if value is None: + return None + if torch.is_tensor(value): + return value.detach().cpu().numpy() if value.ndim > 0 else None + arr = np.asarray(value) + return arr if arr.ndim > 0 else None + + +def _validate_masked_passthrough_stats(feature_stats: dict[str, Any], mask: list[bool], key: str) -> None: + min_values = _stats_array(feature_stats.get("min")) + max_values = _stats_array(feature_stats.get("max")) + if min_values is None or max_values is None: + return + + mask_array = np.asarray(mask, dtype=bool) + if ( + mask_array.ndim 
!= 1 + or min_values.shape[-1] != mask_array.shape[0] + or max_values.shape[-1] != mask_array.shape[0] + or not bool((~mask_array).any()) + ): + return + + passthrough_min = min_values[..., ~mask_array] + passthrough_max = max_values[..., ~mask_array] + if bool(((passthrough_min < -1.0) | (passthrough_max > 1.0)).any()): + raise ValueError( + f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True." + ) + + +def _feature_names_from_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None: + if dataset_meta is None: + return None + + root = getattr(dataset_meta, "root", None) + candidate_roots = [] + if root is not None: + repo_id = str(getattr(dataset_meta, "repo_id", "") or "").strip() + if repo_id: + candidate_roots.append(Path(root) / repo_id) + candidate_roots.append(Path(root)) + for candidate_root in candidate_roots: + info_path = candidate_root / "meta" / "info.json" + if info_path.exists(): + try: + with info_path.open("r", encoding="utf-8") as f: + info = json.load(f) + names = _flatten_feature_names((info.get("features") or {}).get(feature_key, {}).get("names")) + if names: + return names + except (OSError, json.JSONDecodeError, AttributeError): + pass + + for container in ( + getattr(getattr(dataset_meta, "info", None), "features", None), + getattr(dataset_meta, "features", None), + ): + if not isinstance(container, dict): + continue + feature = container.get(feature_key) + if not isinstance(feature, dict): + continue + names = _flatten_feature_names(feature.get("names")) + if names: + return names + return None + + +def _add_gripper_masks_to_stats( + dataset_stats: dict[str, dict[str, Any]] | None, + dataset_meta: Any | None, + *, + normalize_gripper: bool, + dataset_feature_names: dict[str, Any] | None = None, +) -> dict[str, dict[str, Any]] | None: + if not dataset_stats: + return dataset_stats + + stats = deepcopy(dataset_stats) + for key in (ACTION, OBS_STATE): + feature_stats = stats.get(key) + if not 
def _normalization_masks_from_stats(
    dataset_stats: dict[str, dict[str, Any]] | None,
) -> dict[str, list[bool]]:
    """Collect the per-dimension normalization masks stored under ``"mask"``.

    Only the ``ACTION`` and ``OBS_STATE`` entries are inspected. A mask is kept
    when it is a flat sequence of booleans. One-dimensional tensors and numpy
    arrays are accepted too and coerced element-wise with ``bool``, mirroring
    the ``.to(dtype=torch.bool)`` coercion done by the masked normalizer, so
    both steps agree on which dimensions are passed through.

    Args:
        dataset_stats: Statistics keyed by feature name, or ``None``.

    Returns:
        A mapping from feature key to its boolean mask; keys without a usable
        mask are omitted.
    """
    masks: dict[str, list[bool]] = {}
    for key in (ACTION, OBS_STATE):
        feature_stats = (dataset_stats or {}).get(key)
        if not isinstance(feature_stats, dict):
            continue
        mask = feature_stats.get("mask")
        if isinstance(mask, Tensor):
            if mask.ndim != 1:
                continue
            mask = [bool(value) for value in mask.detach().cpu().tolist()]
        elif isinstance(mask, np.ndarray):
            if mask.ndim != 1:
                continue
            mask = [bool(value) for value in mask.tolist()]
        elif isinstance(mask, tuple):
            mask = list(mask)
        if isinstance(mask, list) and all(isinstance(value, bool) for value in mask):
            masks[key] = mask
    return masks
@ProcessorStepRegistry.register(name="molmoact2_clamp_normalized")
@dataclass
class MolmoAct2ClampNormalizedProcessorStep(ProcessorStep):
    """Clamp normalized state and action tensors to ``[-1, 1]``.

    Dimensions whose entry in ``normalization_masks`` is ``False`` are returned
    unchanged, but a ``ValueError`` is raised if any of them lies outside
    ``[-1, 1]``.
    """

    # Optional per-key boolean masks over the last feature axis.
    normalization_masks: dict[str, list[bool]] | None = None

    @staticmethod
    def _broadcast_feature_mask(mask: list[bool], tensor: Tensor) -> Tensor | None:
        """Return ``mask`` as a bool tensor with leading singleton axes, or ``None`` on shape mismatch."""
        as_bool = torch.tensor(mask, device=tensor.device, dtype=torch.bool)
        if as_bool.ndim != 1 or tensor.shape[-1] != as_bool.shape[0]:
            return None
        for _ in range(tensor.ndim - as_bool.ndim):
            as_bool = as_bool.unsqueeze(0)
        return as_bool

    @staticmethod
    def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None:
        """Raise ``ValueError`` when an unmasked (pass-through) value is outside ``[-1, 1]``."""
        untouched = tensor[~mask.expand_as(tensor)]
        if untouched.numel() == 0:
            return
        too_low = untouched < -1.0
        too_high = untouched > 1.0
        if bool((too_low | too_high).any()):
            raise ValueError(
                f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
            )

    def _clamp_tensor(self, tensor: Tensor, key: str) -> Tensor:
        """Clamp ``tensor``; masked-out dimensions are validated and left as they are."""
        configured = self.normalization_masks or {}
        raw_mask = configured.get(key)
        feature_mask = None if raw_mask is None else self._broadcast_feature_mask(raw_mask, tensor)
        if feature_mask is None:
            # No mask, or one that does not fit the last axis: clamp everything.
            return tensor.clamp(-1.0, 1.0)
        self._validate_masked_passthrough_range(tensor, feature_mask, key)
        return torch.where(feature_mask, tensor.clamp(-1.0, 1.0), tensor)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a shallow copy of ``transition`` with state and action clamped."""
        result = transition.copy()

        obs = result.get(TransitionKey.OBSERVATION)
        if isinstance(obs, dict) and OBS_STATE in obs:
            obs = obs.copy()
            obs[OBS_STATE] = self._clamp_tensor(torch.as_tensor(obs[OBS_STATE]), OBS_STATE)
            result[TransitionKey.OBSERVATION] = obs

        act = result.get(TransitionKey.ACTION)
        if act is not None:
            result[TransitionKey.ACTION] = self._clamp_tensor(torch.as_tensor(act), ACTION)
        return result

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged."""
        return features
env_action_dim: int | None = None + + def __post_init__(self) -> None: + require_package("transformers", extra="molmoact2") + + checkpoint_location = _resolve_checkpoint_location( + self.checkpoint_path, + revision=self.checkpoint_revision, + force_download=bool(self.checkpoint_force_download), + ) + self.processor = _load_local_molmoact2_processor(checkpoint_location) + self.action_processor = None + if self.action_mode in {"discrete", "both"}: + require_package("scipy", extra="molmoact2") + if UniversalActionProcessor is None: + raise RuntimeError("transformers and scipy are required to load MolmoAct2 action tokenizer.") + self.action_processor = UniversalActionProcessor.from_pretrained_local( + self.discrete_action_tokenizer, + ) + self._action_start_id = _single_token_id(self.processor.tokenizer, ACTION_START_TOKEN) + self._action_end_id = _single_token_id(self.processor.tokenizer, ACTION_END_TOKEN) + self._eos_token = self.processor.tokenizer.eos_token or "" + self._eos_token_id = self.processor.tokenizer.eos_token_id + + def get_config(self) -> dict[str, Any]: + return { + "checkpoint_path": self.checkpoint_path, + "checkpoint_revision": self.checkpoint_revision, + "checkpoint_force_download": self.checkpoint_force_download, + "action_mode": self.action_mode, + "discrete_action_tokenizer": self.discrete_action_tokenizer, + "image_keys": list(self.image_keys), + "allow_image_key_fallback": self.allow_image_key_fallback, + "setup_type": self.setup_type, + "control_mode": self.control_mode, + "normalize_language": self.normalize_language, + "add_setup_tokens": self.add_setup_tokens, + "add_control_tokens": self.add_control_tokens, + "num_state_tokens": self.num_state_tokens, + "max_sequence_length": self.max_sequence_length, + "chunk_size": self.chunk_size, + "max_action_dim": self.max_action_dim, + "env_action_dim": self.env_action_dim, + } + + def _resolve_max_sequence_length( + self, + *, + num_images: int, + state_dim: int, + action_dim: int, + 
action_horizon: int, + include_discrete_action: bool, + ) -> int: + if self.max_sequence_length is not None: + return int(self.max_sequence_length) + return infer_molmoact2_max_sequence_length( + num_images=num_images, + state_dim=state_dim, + action_dim=action_dim, + action_horizon=action_horizon, + include_discrete_action=include_discrete_action, + ) + + def _batch_size(self, observation: dict[str, Any], action: Tensor | None) -> int: + if action is not None: + return int(action.shape[0]) + state = observation.get(OBS_STATE) + if torch.is_tensor(state) or isinstance(state, np.ndarray): + return int(state.shape[0]) if getattr(state, "ndim", 0) > 1 else 1 + for key in self._resolve_image_keys(observation): + value = observation[key] + if torch.is_tensor(value) or isinstance(value, np.ndarray): + return int(value.shape[0]) if getattr(value, "ndim", 0) == 4 else 1 + return 1 + + @staticmethod + def _observation_image_keys(observation: dict[str, Any]) -> list[str]: + keys = [key for key in observation if str(key).startswith(f"{OBS_IMAGES}.")] + if not keys: + keys = [key for key in observation if str(key).startswith("observation.image")] + return sorted(keys) + + def _resolve_image_keys(self, observation: dict[str, Any]) -> list[str]: + if self.image_keys: + missing = [key for key in self.image_keys if key not in observation] + if missing: + fallback_keys = self._observation_image_keys(observation) + if self.allow_image_key_fallback and fallback_keys: + return fallback_keys + raise ValueError(f"MolmoAct2 image_keys missing from observation: {missing}.") + return list(self.image_keys) + keys = self._observation_image_keys(observation) + if not keys: + raise ValueError("MolmoAct2 requires at least one image observation.") + return sorted(keys) + + def _extract_images(self, observation: dict[str, Any], batch_size: int) -> list[list[np.ndarray]]: + images_by_example: list[list[np.ndarray]] = [[] for _ in range(batch_size)] + for key in 
def _extract_state(self, observation: dict[str, Any], batch_size: int) -> Tensor:
    """Return the observation state as a float32 tensor with a leading batch axis.

    A 1-D state gets a batch axis of size one prepended.

    Raises:
        ValueError: If the state key is absent, or its leading dimension does
            not equal ``batch_size``.
    """
    if OBS_STATE not in observation:
        raise ValueError("MolmoAct2 requires observation.state for discrete state prompting.")

    state_tensor = torch.as_tensor(observation[OBS_STATE], dtype=torch.float32)
    batched = state_tensor.unsqueeze(0) if state_tensor.ndim == 1 else state_tensor
    if int(batched.shape[0]) != batch_size:
        raise ValueError(f"State batch size {batched.shape[0]} does not match batch size {batch_size}.")
    return batched
+ ) + padded = torch.zeros( + (*action.shape[:-1], self.max_action_dim), + device=action.device, + dtype=torch.float32, + ) + padded[..., : action.shape[-1]] = action.to(dtype=torch.float32) + action_dim_is_pad = torch.ones( + (action.shape[0], self.max_action_dim), device=action.device, dtype=torch.bool + ) + action_dim_is_pad[:, : action.shape[-1]] = False + if action_is_pad is None: + action_horizon_is_pad = torch.zeros(action.shape[:2], device=action.device, dtype=torch.bool) + else: + action_horizon_is_pad = torch.as_tensor(action_is_pad, device=action.device, dtype=torch.bool) + if action_horizon_is_pad.ndim == 1: + action_horizon_is_pad = action_horizon_is_pad.unsqueeze(0) + if tuple(action_horizon_is_pad.shape) != tuple(action.shape[:2]): + raise ValueError( + "action_is_pad must match action horizon shape: " + f"got {tuple(action_horizon_is_pad.shape)} for action {tuple(action.shape)}." + ) + return padded, action_horizon_is_pad, action_dim_is_pad + + def _build_labels(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: + labels = torch.full_like(input_ids, -100) + for batch_idx in range(input_ids.shape[0]): + valid = attention_mask[batch_idx].to(dtype=torch.bool) + row = input_ids[batch_idx] + starts = (row == self._action_start_id).nonzero(as_tuple=False).flatten().tolist() + ends = (row == self._action_end_id).nonzero(as_tuple=False).flatten().tolist() + end_ptr = 0 + for start in starts: + while end_ptr < len(ends) and ends[end_ptr] < start: + end_ptr += 1 + if end_ptr >= len(ends): + raise ValueError( + "Found without matching in MolmoAct2 labels." 
+ ) + end = int(ends[end_ptr]) + label_end = end + 1 + if ( + self._eos_token_id is not None + and label_end < int(row.shape[0]) + and int(row[label_end]) == int(self._eos_token_id) + ): + label_end += 1 + labels[batch_idx, start:label_end] = row[start:label_end] + end_ptr += 1 + if not starts: + raise ValueError("No discrete action span found in MolmoAct2 training text.") + labels[batch_idx] = torch.where( + valid, labels[batch_idx], torch.full_like(labels[batch_idx], -100) + ) + return labels + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + observation = transition.get(TransitionKey.OBSERVATION) or {} + if not isinstance(observation, dict): + raise ValueError("MolmoAct2 expected an observation dictionary.") + complementary = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) + + raw_action = transition.get(TransitionKey.ACTION) + action = torch.as_tensor(raw_action, dtype=torch.float32) if raw_action is not None else None + batch_size = self._batch_size(observation, action) + state = self._extract_state(observation, batch_size) + images_by_example = self._extract_images(observation, batch_size) + + task_source = complementary.get("task") + if task_source is None: + task_source = observation.get("task") + if task_source is None: + task_source = observation.get("observation.language") + if task_source is None: + task_source = complementary.get("language_instruction") + tasks = _as_text_list(task_source, batch_size) + if self.normalize_language: + tasks = [_normalize_question_text(task) for task in tasks] + complementary["task"] = tasks + + action_padded = None + action_horizon_is_pad = None + action_dim_is_pad = torch.ones((batch_size, self.max_action_dim), dtype=torch.bool) + real_action_dim = int(self.env_action_dim or 0) + if action is not None: + action_is_pad = complementary.get("action_is_pad") + if action_is_pad is None: + action_is_pad = complementary.get("action_horizon_is_pad") + 
action_padded, action_horizon_is_pad, action_dim_is_pad = self._pad_action(action, action_is_pad) + real_action_dim = int(action.shape[-1]) + elif real_action_dim > 0: + action_dim_is_pad[:, :real_action_dim] = False + + prompt_texts: list[str] = [] + full_texts: list[str] = [] + flat_images: list[np.ndarray] = [] + state_np = state.detach().cpu().numpy() + build_action_labels = action is not None and self.action_mode in {"discrete", "both"} + for batch_idx in range(batch_size): + images = images_by_example[batch_idx] + flat_images.extend(images) + discrete_state = _build_discrete_state_string(state_np[batch_idx], self.num_state_tokens) + prompt = _build_robot_text( + task=tasks[batch_idx], + discrete_state_string=discrete_state, + setup_type=self.setup_type, + control_mode=self.control_mode, + add_setup_tokens=self.add_setup_tokens, + add_control_tokens=self.add_control_tokens, + num_images=len(images), + ) + prompt_texts.append(prompt) + if build_action_labels: + if self.action_processor is None: + raise ValueError("Discrete MolmoAct2 training requires an action tokenizer.") + answer = _build_discrete_action_string( + action[batch_idx].detach().cpu().numpy(), self.action_processor + ) + full_texts.append(f"{prompt}{answer}{self._eos_token}") + else: + full_texts.append(prompt) + + text = full_texts if build_action_labels else prompt_texts + inputs = self.processor(text=text, images=flat_images, return_tensors="pt", padding=True) + if action is None: + action_horizon = self.chunk_size + elif action.ndim == 2: + action_horizon = 1 + else: + action_horizon = int(action.shape[1]) + max_sequence_length = self._resolve_max_sequence_length( + num_images=max((len(images) for images in images_by_example), default=0), + state_dim=int(state.shape[-1]), + action_dim=max(real_action_dim, 1), + action_horizon=action_horizon, + include_discrete_action=build_action_labels, + ) + if int(inputs["input_ids"].shape[1]) > max_sequence_length: + raise ValueError( + f"MolmoAct2 
@ProcessorStepRegistry.register(name="molmoact2_clamp_action")
@dataclass
class MolmoAct2ClampActionProcessorStep(ProcessorStep):
    """Clamp the transition's action, when present, to ``[-1, 1]``."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a shallow copy of ``transition`` whose action is clamped."""
        result = transition.copy()
        raw_action = result.get(TransitionKey.ACTION)
        if raw_action is None:
            return result
        result[TransitionKey.ACTION] = torch.as_tensor(raw_action).clamp(-1.0, 1.0)
        return result

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged."""
        return features
config.checkpoint_path, + revision=config.checkpoint_revision, + force_download=bool(config.checkpoint_force_download), + norm_tag=config.norm_tag, + ) + + image_keys = list(config.image_keys) + visual_feature_keys = [ + key for key, feature in config.input_features.items() if feature.type == FeatureType.VISUAL + ] + if not image_keys and isinstance(hf_metadata.get("camera_keys"), list): + metadata_image_keys = [str(key) for key in hf_metadata["camera_keys"]] + if not visual_feature_keys or all(key in config.input_features for key in metadata_image_keys): + image_keys = metadata_image_keys + if not image_keys: + image_keys = visual_feature_keys + setup_type = config.setup_type or str(hf_metadata.get("setup_type") or "") + control_mode = config.control_mode or str(hf_metadata.get("control_mode") or "") + chunk_size = int(hf_metadata.get("action_horizon") or config.chunk_size) + + masked_dataset_stats = _add_gripper_masks_to_stats( + dataset_stats, + dataset_meta, + normalize_gripper=config.normalize_gripper, + dataset_feature_names=config.dataset_feature_names, + ) + normalization_masks = _normalization_masks_from_stats(masked_dataset_stats) + + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + MolmoAct2MaskedNormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=masked_dataset_stats, + ), + MolmoAct2ClampNormalizedProcessorStep(normalization_masks=normalization_masks), + MolmoAct2PackInputsProcessorStep( + checkpoint_path=config.checkpoint_path, + checkpoint_revision=config.checkpoint_revision, + checkpoint_force_download=config.checkpoint_force_download, + action_mode=config.action_mode, + discrete_action_tokenizer=config.discrete_action_tokenizer, + image_keys=image_keys, + allow_image_key_fallback=not bool(config.image_keys), + setup_type=setup_type, + control_mode=control_mode, + 
normalize_language=config.normalize_language, + add_setup_tokens=config.add_setup_tokens, + add_control_tokens=config.add_control_tokens, + num_state_tokens=config.num_state_tokens, + max_sequence_length=config.max_sequence_length, + chunk_size=chunk_size, + max_action_dim=config.expected_max_action_dim, + env_action_dim=env_action_dim, + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + MolmoAct2ClampActionProcessorStep(), + MolmoAct2MaskedUnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=masked_dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/tests/policies/molmoact2/test_molmoact2.py b/tests/policies/molmoact2/test_molmoact2.py new file mode 100644 index 000000000..3631bcc9b --- /dev/null +++ b/tests/policies/molmoact2/test_molmoact2.py @@ -0,0 +1,1397 @@ +#!/usr/bin/env python + +# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for MolmoAct2's LeRobot policy interface.""" + +# ruff: noqa: E402 + +from __future__ import annotations + +import json +from collections import deque +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +import torch.nn.functional as F # noqa: N812 + +pytest.importorskip("transformers") +pytest.importorskip("scipy") + +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies import get_policy_class, make_policy_config +from lerobot.policies.molmoact2 import ( + configuration_molmoact2 as molmoact2_config, + modeling_molmoact2 as molmoact2_modeling, + processor_molmoact2 as molmoact2_processor, +) +from lerobot.policies.molmoact2.configuration_molmoact2 import ( + MolmoAct2Config, + MolmoAct2CosineDecayWithWarmupSchedulerConfig, + infer_molmoact2_max_sequence_length, +) +from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy +from lerobot.policies.molmoact2.processor_molmoact2 import ( + MolmoAct2ClampNormalizedProcessorStep, + MolmoAct2MaskedNormalizerProcessorStep, + MolmoAct2MaskedUnnormalizerProcessorStep, + MolmoAct2PackInputsProcessorStep, + _add_gripper_masks_to_stats, + _build_discrete_state_string, + _normalize_question_text, + make_molmoact2_pre_post_processors, +) +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.types import TransitionKey +from lerobot.utils.constants import ACTION, OBS_STATE + + +def test_molmoact2_policy_registration(): + cfg = make_policy_config("molmoact2", checkpoint_path="/tmp/not-a-real-checkpoint") + + assert cfg.type == "molmoact2" + assert cfg.action_mode == "both" + assert cfg.normalize_gripper is False + assert cfg.enable_knowledge_insulation is False + assert cfg.freeze_embedding is True + assert cfg.per_episode_seed is False + assert cfg.eval_seed is None + assert cfg.normalize_language is True + assert cfg.get_scheduler_preset().num_decay_steps is None + assert cfg.action_delta_indices == 
def test_molmoact2_scheduler_decay_steps_auto_match_training_steps():
    """With ``num_decay_steps=None`` the LR after all training steps is 1e-4."""
    total_steps = 100
    weight = torch.nn.Parameter(torch.ones(()))
    adamw = torch.optim.AdamW([weight], lr=0.001)
    scheduler_cfg = MolmoAct2CosineDecayWithWarmupSchedulerConfig(
        peak_lr=0.01,
        decay_lr=0.001,
        num_warmup_steps=10,
        num_decay_steps=None,
    )

    lr_scheduler = scheduler_cfg.build(adamw, num_training_steps=total_steps)
    for _step in range(total_steps):
        adamw.step()
        lr_scheduler.step()

    final_lrs = lr_scheduler.get_last_lr()
    assert final_lrs == pytest.approx([0.0001])
expected_second = torch.Generator().manual_seed( + MolmoAct2Policy._combine_rollout_seeds(first_seed=1003, batch_size=3) + ) + assert torch.allclose(torch.rand(4, generator=second), torch.rand(4, generator=expected_second)) + + policy.reset() + new_task = policy._rollout_generator_for_inputs( + {"task": ["place", "place", "place"]}, + batch_size=3, + device=torch.device("cpu"), + ) + expected_new_task = torch.Generator().manual_seed( + MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3) + ) + assert torch.allclose(torch.rand(4, generator=new_task), torch.rand(4, generator=expected_new_task)) + + +def test_molmoact2_gripper_mask_uses_feature_names(tmp_path): + meta_dir = tmp_path / "meta" + meta_dir.mkdir() + (meta_dir / "info.json").write_text( + json.dumps( + { + "features": { + ACTION: {"names": {"motors": ["x", "gripper"]}}, + OBS_STATE: {"names": {"motors": ["joint", "gripper"]}}, + } + } + ), + encoding="utf-8", + ) + dataset_meta = SimpleNamespace(root=tmp_path) + stats = { + ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}, + OBS_STATE: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}, + } + + masked_stats = _add_gripper_masks_to_stats(stats, dataset_meta, normalize_gripper=False) + + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, False] + assert masked_stats[OBS_STATE]["mask"] == [True, False] + + features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)), + } + norm_map = { + FeatureType.ACTION: NormalizationMode.QUANTILES, + FeatureType.STATE: NormalizationMode.QUANTILES, + } + transition = { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 0.7]])}, + TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]), + } + normalizer = MolmoAct2MaskedNormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=masked_stats, + ) + normalized = normalizer(transition) + + assert 
torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 0.7]])) + assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, -0.7]])) + + with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"): + normalizer( + { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])}, + TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]), + } + ) + + unnormalizer = MolmoAct2MaskedUnnormalizerProcessorStep( + features={ACTION: features[ACTION]}, + norm_map=norm_map, + stats=masked_stats, + ) + unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, -0.7]])}) + + assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, -0.7]])) + + +def test_molmoact2_gripper_mask_validates_dataset_stats(tmp_path): + meta_dir = tmp_path / "meta" + meta_dir.mkdir() + (meta_dir / "info.json").write_text( + json.dumps({"features": {ACTION: {"names": ["x", "gripper"]}}}), + encoding="utf-8", + ) + stats = { + ACTION: { + "min": [-0.5, -2.0], + "max": [0.5, 0.5], + } + } + + with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"): + _add_gripper_masks_to_stats(stats, SimpleNamespace(root=tmp_path), normalize_gripper=False) + + masked_stats = _add_gripper_masks_to_stats(stats, SimpleNamespace(root=tmp_path), normalize_gripper=True) + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, True] + + +def test_molmoact2_clamp_normalized_respects_masked_gripper_dims(): + step = MolmoAct2ClampNormalizedProcessorStep( + normalization_masks={ + ACTION: [True, False], + OBS_STATE: [True, False], + } + ) + transition = { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[-2.0, 0.8]])}, + TransitionKey.ACTION: torch.tensor([[2.0, -0.8]]), + } + + clamped = step(transition) + + assert torch.equal(clamped[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[-1.0, 0.8]])) + assert torch.equal(clamped[TransitionKey.ACTION], torch.tensor([[1.0, -0.8]])) + + 
with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"): + step({TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[0.0, 1.2]])}}) + + +def test_molmoact2_normalize_gripper_true_keeps_all_dims_normalized(tmp_path): + meta_dir = tmp_path / "meta" + meta_dir.mkdir() + (meta_dir / "info.json").write_text( + json.dumps({"features": {ACTION: {"names": ["x", "gripper"]}}}), + encoding="utf-8", + ) + stats = {ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}} + + masked_stats = _add_gripper_masks_to_stats( + stats, + SimpleNamespace(root=tmp_path), + normalize_gripper=True, + ) + + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, True] + + +def test_molmoact2_uses_supplied_stats_with_repo_scoped_names(tmp_path): + repo_root = tmp_path / "test-org" / "libero" + (repo_root / "meta").mkdir(parents=True) + (repo_root / "meta" / "info.json").write_text( + json.dumps({"features": {ACTION: {"names": ["x", "gripper"]}}}), + encoding="utf-8", + ) + base_stats = {ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}} + + masked_stats = _add_gripper_masks_to_stats( + base_stats, + SimpleNamespace(root=tmp_path, repo_id="test-org/libero"), + normalize_gripper=False, + ) + + assert masked_stats is not None + assert masked_stats[ACTION]["q01"] == [0.0, 0.0] + assert masked_stats[ACTION]["mask"] == [True, False] + + +def test_molmoact2_uses_config_feature_names_without_dataset_meta(): + base_stats = {ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}} + + masked_stats = _add_gripper_masks_to_stats( + base_stats, + None, + normalize_gripper=False, + dataset_feature_names={ACTION: ["x", "gripper"]}, + ) + + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, False] + + +def test_molmoact2_processor_uses_available_visual_features_over_missing_metadata_keys(monkeypatch): + monkeypatch.setattr( + molmoact2_processor, + "_load_hf_norm_stats_for_tag", + lambda *args, **kwargs: ( + {}, + {"camera_keys": 
["observation.images.image", "observation.images.wrist_image"]}, + ), + ) + monkeypatch.setattr(MolmoAct2PackInputsProcessorStep, "__post_init__", lambda self: None) + cfg = MolmoAct2Config( + checkpoint_path="/tmp/not-a-real-checkpoint", + norm_tag="libero", + input_features={ + "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + "observation.images.image2": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)), + }, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))}, + ) + + preprocessor, _ = make_molmoact2_pre_post_processors(cfg) + pack_step = next( + step for step in preprocessor.steps if isinstance(step, MolmoAct2PackInputsProcessorStep) + ) + + assert pack_step.image_keys == ["observation.images.image", "observation.images.image2"] + assert pack_step.allow_image_key_fallback is True + + +def test_molmoact2_metadata_image_keys_can_fall_back_to_observation_keys(): + step = object.__new__(MolmoAct2PackInputsProcessorStep) + step.image_keys = ["observation.images.image", "observation.images.wrist_image"] + step.allow_image_key_fallback = True + observation = { + "observation.images.image": torch.zeros(3, 4, 4), + "observation.images.image2": torch.zeros(3, 4, 4), + } + + assert step._resolve_image_keys(observation) == ["observation.images.image", "observation.images.image2"] + + +def test_molmoact2_explicit_image_keys_stay_strict(): + step = object.__new__(MolmoAct2PackInputsProcessorStep) + step.image_keys = ["observation.images.image", "observation.images.wrist_image"] + step.allow_image_key_fallback = False + observation = { + "observation.images.image": torch.zeros(3, 4, 4), + "observation.images.image2": torch.zeros(3, 4, 4), + } + + with pytest.raises(ValueError, match="wrist_image"): + step._resolve_image_keys(observation) + + +def test_enable_lora_vlm_builds_policy_local_peft_config(): + pytest.importorskip("peft") + 
policy_cfg = MolmoAct2Config( + checkpoint_path="/tmp/not-a-real-checkpoint", + device="cpu", + enable_lora_vlm=True, + lora_rank=64, + push_to_hub=False, + ) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = policy_cfg + + peft_config = policy._build_inner_lora_config() + + assert peft_config.r == 64 + assert peft_config.target_modules == policy._get_inner_peft_targets()["target_modules"] + assert not policy_cfg.use_peft + + +def test_cuda_graph_managers_are_inference_only(): + class DummyManager: + def __init__(self): + self.enabled = None + + def set_enabled(self, enabled): + self.enabled = enabled + + class DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.action_cuda_graph_manager = DummyManager() + + def _require_action_expert(self): + return torch.nn.Linear(1, 1) + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = DummyBackbone() + self.depth_decode_cuda_graph_manager = DummyManager() + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(train_action_expert_only=False, enable_inference_cuda_graph=True) + policy.model = DummyModel() + + policy.train() + assert policy.model.model.action_cuda_graph_manager.enabled is False + assert policy.model.depth_decode_cuda_graph_manager.enabled is False + + policy.eval() + assert policy.model.model.action_cuda_graph_manager.enabled is True + assert policy.model.depth_decode_cuda_graph_manager.enabled is True + + policy.config.enable_inference_cuda_graph = False + policy.eval() + assert policy.model.model.action_cuda_graph_manager.enabled is False + assert policy.model.depth_decode_cuda_graph_manager.enabled is False + + +def test_lora_action_expert_target_is_opt_in(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + lora_rank=64, + lora_alpha=16, + lora_dropout=0.05, + 
lora_bias="none", + enable_lora_action_expert=False, + ) + + targets = policy._get_default_peft_targets()["target_modules"] + + assert "transformer|vision_backbone" in targets + assert "action_expert" not in targets + + policy.config.enable_lora_action_expert = True + targets = policy._get_default_peft_targets()["target_modules"] + + assert "action_expert" in targets + assert "state_encoder" not in targets + assert "state_norm" not in targets + assert "kv_proj" not in targets + + +def test_enable_lora_vlm_wraps_loaded_hf_model_locally(): + pytest.importorskip("peft") + + class DummyInnerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.transformer = torch.nn.Module() + self.transformer.wq = torch.nn.Linear(2, 2) + self.action_expert = torch.nn.Module() + self.action_expert.action_embed = torch.nn.Linear(2, 2) + + class DummyHFModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = {} + self.model = DummyInnerModel() + + def forward(self, x): + return self.model.transformer.wq(x) + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + checkpoint_path="/tmp/base", + lora_rank=2, + lora_alpha=4, + lora_dropout=0.0, + lora_bias="none", + enable_lora_action_expert=False, + train_action_expert_only=False, + enable_inference_cuda_graph=False, + ) + policy.model = DummyHFModel() + + policy._apply_lora_adapters() + + assert policy._backbone() is policy.model.base_model.model.model + trainable = [name for name, param in policy.named_parameters() if param.requires_grad] + assert trainable + assert any("lora_" in name for name in trainable) + assert any("action_expert.action_embed" in name and "lora_" not in name for name in trainable) + assert policy.model(torch.ones(1, 2)).shape == (1, 2) + + +def test_lora_vlm_unfreezes_action_expert_base_weights(): + class DummyInnerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.transformer = 
torch.nn.Module() + self.transformer.wq = torch.nn.Linear(2, 2) + self.action_expert = torch.nn.Module() + self.action_expert.action_embed = torch.nn.Linear(2, 2) + + class DummyHFModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = DummyInnerModel() + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.model = DummyHFModel() + + for param in policy.parameters(): + param.requires_grad_(False) + policy._unfreeze_action_expert_parameters() + + trainable = [name for name, param in policy.named_parameters() if param.requires_grad] + assert trainable + assert all("action_expert" in name for name in trainable) + + +def test_train_action_expert_only_requires_continuous_action_mode(): + with pytest.raises(ValueError, match="requires action_mode='continuous'"): + MolmoAct2Config(action_mode="both", train_action_expert_only=True) + + with pytest.raises(ValueError, match="incompatible with enable_lora_vlm"): + MolmoAct2Config(action_mode="continuous", train_action_expert_only=True, enable_lora_vlm=True) + + cfg = MolmoAct2Config(action_mode="continuous", train_action_expert_only=True) + assert cfg.train_action_expert_only + + +def test_molmoact2_sequence_length_is_inferred_from_fixed_token_budget(): + cfg = MolmoAct2Config( + action_mode="both", + chunk_size=10, + n_action_steps=10, + image_keys=["observation.images.image", "observation.images.wrist_image"], + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))}, + ) + + assert cfg.max_sequence_length is None + assert cfg.inferred_max_sequence_length() == 640 + assert cfg.inferred_max_sequence_length(include_discrete_action=False) == 576 + assert ( + infer_molmoact2_max_sequence_length( + num_images=2, + state_dim=8, + action_dim=7, + action_horizon=30, + include_discrete_action=True, + ) + == 768 + ) + + +def 
test_molmoact2_sequence_length_override_is_preserved(): + cfg = MolmoAct2Config(max_sequence_length=1024) + + assert cfg.inferred_max_sequence_length(num_images=2, state_dim=8, action_dim=7) == 1024 + + +def test_train_action_expert_only_freezes_non_action_expert_params(): + class DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.transformer = torch.nn.Linear(2, 2) + self.vision_backbone = torch.nn.Linear(2, 2) + self.action_expert = torch.nn.Linear(2, 2) + + def _require_action_expert(self): + return self.action_expert + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = DummyBackbone() + self.lm_head = torch.nn.Linear(2, 2) + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(train_action_expert_only=True) + policy.model = DummyModel() + + policy._freeze_non_action_expert_parameters() + policy.train() + + assert policy.model.model.action_expert.training + assert not policy.model.training + assert not policy.model.model.transformer.training + assert all(param.requires_grad for param in policy.model.model.action_expert.parameters()) + assert not any(param.requires_grad for param in policy.model.model.transformer.parameters()) + assert not any(param.requires_grad for param in policy.model.model.vision_backbone.parameters()) + assert not any(param.requires_grad for param in policy.model.lm_head.parameters()) + + +def test_load_hf_model_accepts_max_action_horizon_schema(monkeypatch): + class DummyLoadedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + max_action_dim=32, + max_action_horizon=30, + action_mode="both", + add_action_expert=True, + ) + self.model = torch.nn.Module() + self.embed_tokens = torch.nn.Embedding(4, 4) + self.lm_head = torch.nn.Linear(4, 4, bias=False) + + def get_input_embeddings(self): + return self.embed_tokens + + loaded_model = DummyLoadedModel() + 
resolved_kwargs = {} + + def fake_resolve_checkpoint_location(checkpoint_path, **kwargs): + resolved_kwargs.update(kwargs) + return checkpoint_path + + config_kwargs = {} + model_kwargs = {} + + class DummyHFConfig: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args + config_kwargs.update(kwargs) + return SimpleNamespace() + + class DummyMolmoAct2ForConditionalGeneration: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args + model_kwargs.update(kwargs) + return loaded_model + + monkeypatch.setattr(molmoact2_modeling, "_resolve_checkpoint_location", fake_resolve_checkpoint_location) + monkeypatch.setattr(molmoact2_modeling, "HFMolmoAct2Config", DummyHFConfig) + monkeypatch.setattr( + molmoact2_modeling, + "MolmoAct2ForConditionalGeneration", + DummyMolmoAct2ForConditionalGeneration, + ) + monkeypatch.setattr(molmoact2_modeling, "_strict_load_safetensors_weights", lambda *args: None) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + checkpoint_path="/tmp/new-schema-checkpoint", + checkpoint_revision="main", + checkpoint_force_download=True, + chunk_size=10, + n_action_steps=10, + action_mode="both", + ) + + policy._load_hf_model() + + assert policy.model is loaded_model + assert not hasattr(policy.model.config, "action_horizon") + assert policy.model.config.max_action_horizon == 10 + assert policy._generation_action_horizon() == 10 + assert resolved_kwargs == {"revision": "main", "force_download": True} + assert "trust_remote_code" not in config_kwargs + assert "trust_remote_code" not in model_kwargs + + +def test_load_hf_model_chunk_size_overrides_larger_than_checkpoint_horizon(monkeypatch): + class DummyLoadedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + max_action_dim=32, + max_action_horizon=10, + action_mode="both", + add_action_expert=True, + ) + self.model = torch.nn.Module() + self.embed_tokens = 
torch.nn.Embedding(4, 4) + self.lm_head = torch.nn.Linear(4, 4, bias=False) + + def get_input_embeddings(self): + return self.embed_tokens + + loaded_model = DummyLoadedModel() + monkeypatch.setattr( + molmoact2_modeling, + "_resolve_checkpoint_location", + lambda checkpoint_path, **kwargs: checkpoint_path, + ) + + class DummyHFConfig: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args, kwargs + return SimpleNamespace() + + class DummyMolmoAct2ForConditionalGeneration: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args, kwargs + return loaded_model + + monkeypatch.setattr(molmoact2_modeling, "HFMolmoAct2Config", DummyHFConfig) + monkeypatch.setattr( + molmoact2_modeling, + "MolmoAct2ForConditionalGeneration", + DummyMolmoAct2ForConditionalGeneration, + ) + monkeypatch.setattr(molmoact2_modeling, "_strict_load_safetensors_weights", lambda *args: None) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + checkpoint_path="/tmp/new-schema-checkpoint", + chunk_size=30, + n_action_steps=30, + action_mode="both", + ) + + policy._load_hf_model() + + assert policy.model.config.max_action_horizon == 30 + assert policy._generation_action_horizon() == 30 + + +def test_load_hf_model_rejects_legacy_action_horizon_schema(monkeypatch): + class DummyLoadedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + max_action_dim=32, + action_horizon=30, + action_mode="both", + add_action_expert=True, + ) + self.model = torch.nn.Module() + + monkeypatch.setattr( + molmoact2_modeling, + "_resolve_checkpoint_location", + lambda checkpoint_path, **kwargs: checkpoint_path, + ) + + class DummyHFConfig: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args, kwargs + return SimpleNamespace() + + class DummyMolmoAct2ForConditionalGeneration: + @classmethod + def from_pretrained(cls, *args, **kwargs): + del args, kwargs + return 
DummyLoadedModel() + + monkeypatch.setattr(molmoact2_modeling, "HFMolmoAct2Config", DummyHFConfig) + monkeypatch.setattr( + molmoact2_modeling, + "MolmoAct2ForConditionalGeneration", + DummyMolmoAct2ForConditionalGeneration, + ) + monkeypatch.setattr(molmoact2_modeling, "_strict_load_safetensors_weights", lambda *args: None) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + checkpoint_path="/tmp/legacy-schema-checkpoint", + chunk_size=10, + n_action_steps=10, + action_mode="both", + ) + + with pytest.raises(ValueError, match="max_action_horizon"): + policy._load_hf_model() + + +def test_rtc_processor_initialization_and_select_action_guard(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(rtc_config=RTCConfig(enabled=True)) + + policy.init_rtc_processor() + + assert policy.rtc_processor is not None + with pytest.raises(AssertionError, match="RTC is not supported for select_action"): + policy.select_action({}) + + +def test_select_action_uses_single_full_batch_queue(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(rtc_config=None, n_action_steps=2) + policy._action_queue = deque(maxlen=2) + calls = 0 + + def predict_action_chunk(batch, **kwargs): + nonlocal calls + del batch, kwargs + calls += 1 + return torch.tensor( + [ + [[1.0], [2.0]], + [[3.0], [4.0]], + ] + ) + + policy.predict_action_chunk = predict_action_chunk + + first = policy.select_action({}) + second = policy.select_action({}) + + assert calls == 1 + assert torch.equal(first, torch.tensor([[1.0], [3.0]])) + assert torch.equal(second, torch.tensor([[2.0], [4.0]])) + + +def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config(action_mode="both", inference_action_mode=None) + 
policy._checkpoint_action_mode = None + + with pytest.raises(ValueError, match="inference_action_mode.*explicitly"): + policy._resolve_inference_action_mode(None) + with pytest.raises(TypeError, match="unexpected keyword argument 'action_mode'"): + policy.predict_action_chunk({}, action_mode="continuous") + + +def test_rtc_generation_uses_previous_chunk_prefix(): + class DummyActionExpert(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + + def prepare_context(self, **kwargs): + del kwargs + return SimpleNamespace() + + def get_or_prepare_modulation_cache(self, timesteps, *, cache_key=None): + del cache_key + return [SimpleNamespace(conditioning=timestep) for timestep in timesteps] + + def forward_with_context(self, actions, timesteps, *, context, modulation=None): + del timesteps, context, modulation + return torch.ones_like(actions) * self.weight + + class DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + flow_matching_num_steps=2, + max_action_horizon=4, + max_action_dim=3, + ) + self.action_expert = DummyActionExpert() + self.batch_size = 1 + + def _require_action_expert(self): + return self.action_expert + + def forward(self, **kwargs): + self.batch_size = int(kwargs["input_ids"].shape[0]) + return SimpleNamespace(past_key_values=object()) + + def _extract_kv_states(self, past_key_values): + del past_key_values + kv = torch.zeros(self.batch_size, 1, 1) + return [(kv, kv)] + + def _get_encoder_attention_mask(self, input_ids, attention_mask): + del input_ids + return attention_mask + + def _depth_gate_from_condition(self, **kwargs): + del kwargs + return None, None + + def _apply_depth_gate_to_layer_kv_states(self, encoder_kv_states, depth_mask, depth_gate): + del depth_mask, depth_gate + return encoder_kv_states + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + 
mask_action_dim_padding=True, + rtc_config=RTCConfig(enabled=True, execution_horizon=2, max_guidance_weight=1.0), + ) + policy.rtc_processor = None + policy.model = torch.nn.Module() + policy.model.model = DummyBackbone() + policy.init_rtc_processor() + model_inputs = { + "input_ids": torch.ones(1, 2, dtype=torch.long), + "attention_mask": torch.ones(1, 2, dtype=torch.long), + } + action_dim_is_pad = torch.tensor([[False, False, False]]) + + without_prefix = policy._generate_actions_from_inputs_with_rtc( + model_inputs=model_inputs, + action_dim_is_pad=action_dim_is_pad, + num_steps=2, + generator=torch.Generator().manual_seed(0), + inference_delay=0, + prev_chunk_left_over=None, + execution_horizon=None, + ) + with_prefix = policy._generate_actions_from_inputs_with_rtc( + model_inputs=model_inputs, + action_dim_is_pad=action_dim_is_pad, + num_steps=2, + generator=torch.Generator().manual_seed(0), + inference_delay=0, + prev_chunk_left_over=torch.zeros(1, 4, 3), + execution_horizon=None, + ) + + assert without_prefix.shape == (1, 4, 3) + assert not torch.allclose(without_prefix, with_prefix) + + +def test_discrete_state_string_matches_molmoact2_bins(): + state = np.asarray([-1.0, 0.0, 1.0, np.nan, np.inf, -np.inf], dtype=np.float32) + + assert _build_discrete_state_string(state, 256) == ( + "" + ) + + +def test_question_normalization_matches_release_prompt_style(): + assert _normalize_question_text("Instruction: Pick up the cube, please!") == "pick up the cube, please" + assert ( + _normalize_question_text("The task is to open drawer. 
Then close it.") == "open drawer; then close it" + ) + + +def test_action_padding_marks_only_real_dimensions(): + step = object.__new__(MolmoAct2PackInputsProcessorStep) + step.max_action_dim = 32 + action = torch.ones(2, 3, 7) + + padded, horizon_is_pad, dim_is_pad = step._pad_action(action, None) + + assert padded.shape == (2, 3, 32) + assert torch.equal(padded[..., :7], action) + assert torch.count_nonzero(padded[..., 7:]) == 0 + assert not horizon_is_pad.any() + assert not dim_is_pad[:, :7].any() + assert dim_is_pad[:, 7:].all() + + +def test_action_dim_padding_loss_reduces_like_old_trainer(): + loss = torch.arange(2 * 2 * 3 * 4, dtype=torch.float32).reshape(2, 2, 3, 4) + action_dim_is_pad = torch.tensor( + [ + [False, False, True, True], + [False, True, True, True], + ] + ) + + reduced = MolmoAct2Policy._apply_action_dim_padding_mask(loss, action_dim_is_pad) + + expected = torch.stack( + [ + loss[0, :, :, :2].sum(dim=-1) / 2, + loss[1, :, :, :1].sum(dim=-1) / 1, + ], + dim=0, + ) + assert torch.equal(reduced, expected) + + +def test_action_chunk_padding_keeps_old_mean_denominator(): + loss = torch.ones(1, 2, 4, 3) + action_horizon_is_pad = torch.tensor([[False, False, True, True]]) + + masked = MolmoAct2Policy._apply_action_chunk_padding_mask(loss, action_horizon_is_pad) + + assert masked.mean().item() == 0.5 + + +def test_selected_discrete_loss_matches_full_causal_lm_loss(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=False, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="none", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(2, 4, 3) + labels = torch.tensor( + [ + [-100, 1, 2, -100], + [-100, -100, 3, 4], + ] + ) + + selected_loss, z_loss = policy._discrete_loss_from_backbone_outputs({"labels": labels}, outputs) + + 
logits = policy.model.lm_head(outputs.last_hidden_state) + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + expected_loss = F.cross_entropy(logits.float().view(-1, 5), shift_labels.view(-1), ignore_index=-100) + assert torch.allclose(selected_loss, expected_loss) + assert z_loss is None + + +def test_discrete_z_loss_matches_old_trainer_formula(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="none", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(2, 4, 3) + labels = torch.tensor( + [ + [-100, 1, 2, -100], + [-100, -100, 3, 4], + ] + ) + + ce_loss, z_loss = policy._discrete_loss_from_backbone_outputs({"labels": labels}, outputs) + + logits = policy.model.lm_head(outputs.last_hidden_state).float() + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + valid = shift_labels != -100 + expected_ce = F.cross_entropy(logits.view(-1, 5), shift_labels.view(-1), ignore_index=-100) + expected_z = 1e-4 * logits.logsumexp(dim=-1)[valid].pow(2).mean() + assert torch.allclose(ce_loss, expected_ce) + assert z_loss is not None + assert torch.allclose(z_loss, expected_z) + + +def test_discrete_reduction_none_preserves_mean_loss(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="root_subsegments_root_tokens", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(3, 5, 3) + labels = torch.tensor( + [ + [-100, 1, -100, -100, -100], + [-100, -100, 2, 3, -100], + [-100, 
4, 3, 2, 1], + ] + ) + + ce_mean, z_mean = policy._discrete_loss_from_backbone_outputs( + {"labels": labels}, + outputs, + reduction="mean", + ) + ce_none, z_none = policy._discrete_loss_from_backbone_outputs( + {"labels": labels}, + outputs, + reduction="none", + ) + + assert ce_none.shape == (3,) + assert z_none is not None + assert z_none.shape == (3,) + assert torch.allclose(ce_none.mean(), ce_mean) + assert torch.allclose(z_none.mean(), z_mean) + + +def test_forward_reduction_none_returns_per_sample_discrete_loss(): + class DummyBackbone(torch.nn.Module): + def __init__(self, hidden_states): + super().__init__() + self.hidden_states = hidden_states + + def forward(self, **kwargs): + del kwargs + return SimpleNamespace(last_hidden_state=self.hidden_states) + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + action_mode="discrete", + inference_action_mode="discrete", + model_dtype="float32", + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="none", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + hidden_states = torch.randn(2, 4, 3) + policy._backbone = lambda: DummyBackbone(hidden_states) + batch = { + "input_ids": torch.ones(2, 4, dtype=torch.long), + "labels": torch.tensor( + [ + [-100, 1, 2, -100], + [-100, -100, 3, 4], + ] + ), + } + + loss_none, metrics_none = policy.forward(batch, reduction="none") + loss_mean, metrics_mean = policy.forward(batch, reduction="mean") + + assert loss_none.shape == (2,) + assert torch.allclose(loss_none.mean(), loss_mean) + assert metrics_none["loss"] == pytest.approx(metrics_mean["loss"]) + + +def test_discrete_root_token_weighting_matches_old_loss_mask_scaling(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + 
discrete_loss_token_weighting="root_subsegments_root_tokens", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(2, 4, 3) + labels = torch.tensor( + [ + [-100, -100, 1, -100], + [-100, 2, 3, 4], + ] + ) + + ce_loss, z_loss = policy._discrete_loss_from_backbone_outputs({"labels": labels}, outputs) + + logits = policy.model.lm_head(outputs.last_hidden_state).float() + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + valid = shift_labels != -100 + log_z = logits.logsumexp(dim=-1) + token_ce = log_z - logits.gather(dim=-1, index=shift_labels.clamp_min(0).unsqueeze(-1)).squeeze(-1) + weights = torch.zeros_like(token_ce) + counts = valid.sum(dim=1).float() + weights[valid] = (2.0 / torch.sqrt(counts))[:, None].expand_as(weights)[valid] + expected_ce = (token_ce * weights).sum() / weights.sum() + expected_z = 1e-4 * (log_z.pow(2) * weights).sum() / weights.sum() + assert torch.allclose(ce_loss, expected_ce) + assert z_loss is not None + assert torch.allclose(z_loss, expected_z) + + +class _DummyActionTokenizer: + def decode(self, tokens, *, time_horizon=None, action_dim=None): + decoded = [] + for token_row in tokens: + decoded.append(np.full((time_horizon, action_dim), sum(token_row), dtype=np.float32)) + return np.stack(decoded) + + +def test_discrete_decode_extracts_action_bins_for_each_batch(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(chunk_size=2) + policy.action_tokenizer = _DummyActionTokenizer() + policy.model = torch.nn.Module() + policy.model.config = SimpleNamespace( + action_start_token_id=10, + action_end_token_id=11, + action_token_start_id=100, + num_action_tokens=4, + action_horizon=2, + ) + + actions = policy._decode_discrete_action_chunk( + torch.tensor( + [ + [10, 100, 101, 11, 2], + [10, 102, 103, 11, 2], + ] + ), + action_dim=2, + ) 
+ + assert actions.shape == (2, 2, 2) + assert torch.equal(actions[0], torch.ones(2, 2)) + assert torch.equal(actions[1], torch.full((2, 2), 5.0)) + + +def test_discrete_predict_action_chunk_uses_hf_cached_generation_path(): + class DummyOutput: + def __init__(self, token_id, batch_size): + logits = torch.full((batch_size, 1, 128), -1e9) + logits[:, :, token_id] = 1.0 + self.logits = logits + self.past_key_values = object() + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + self.config = SimpleNamespace( + action_start_token_id=10, + action_end_token_id=11, + action_token_start_id=100, + num_action_tokens=4, + action_horizon=2, + ) + self.tokens = [10, 100, 101, 11, 2] + self.index = 0 + + def forward(self, **kwargs): + batch_size = int(kwargs["input_ids"].shape[0]) + return DummyOutput(self.tokens[self.index], batch_size) + + def _consume_generation_tokens(self, token_ids, *, past_key_values, attention_mask): + del past_key_values + self.index += 1 + if attention_mask is not None: + attention_mask = torch.cat([attention_mask, torch.ones_like(token_ids[:, None])], dim=-1) + return DummyOutput(self.tokens[self.index], int(token_ids.shape[0])), attention_mask + + def _require_eos_token_id(self): + return 2 + + def _action_token_id_to_bin(self): + return {100: 0, 101: 1, 102: 2, 103: 3} + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + action_mode="discrete", + inference_action_mode="discrete", + model_dtype="float32", + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))}, + discrete_generation_max_steps=None, + discrete_action_tokenizer="unused", + chunk_size=2, + n_action_steps=1, + rtc_config=None, + ) + policy._checkpoint_action_mode = None + policy.model = DummyModel() + policy.action_tokenizer = _DummyActionTokenizer() + + actions = policy.predict_action_chunk( + { + "input_ids": 
torch.ones(1, 3, dtype=torch.long), + "attention_mask": torch.ones(1, 3, dtype=torch.long), + } + ) + + assert policy.model.index == 4 + assert actions.shape == (1, 1, 2) + assert torch.equal(actions, torch.ones(1, 1, 2)) + + +def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled(): + class DummyOutput: + def __init__(self, token_id, past_key_values): + logits = torch.full((1, 1, 128), -1e9) + logits[:, :, token_id] = 1.0 + self.logits = logits + self.past_key_values = past_key_values + + class DummyLmHead(torch.nn.Module): + def forward(self, hidden_states): + token_id = int(hidden_states[0, 0, 0].item()) + logits = torch.full((1, 1, 128), -1e9) + logits[:, :, token_id] = 1.0 + return logits + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + self.lm_head = DummyLmHead() + self.config = SimpleNamespace( + action_start_token_id=10, + action_end_token_id=11, + action_token_start_id=100, + num_action_tokens=4, + action_horizon=2, + ) + self.tokens = [10, 100, 101, 11, 2] + self.index = 0 + self.used_static_cache = False + self.graph_steps = 0 + + def forward(self, **kwargs): + self.used_static_cache = kwargs.get("past_key_values") == "static-cache" + return DummyOutput(self.tokens[self.index], kwargs.get("past_key_values")) + + def _make_ar_decode_static_cache(self, inputs, *, max_steps): + assert int(inputs["input_ids"].shape[1]) == 3 + assert max_steps == 32 + return "static-cache" + + def _make_depth_decode_attention_bias(self, inputs, past_key_values): + assert past_key_values == "static-cache" + return torch.ones(1, 1, 35, 35, dtype=torch.float32) + + def _run_ar_decode_step(self, token_ids, *, past_key_values, attention_bias): + assert past_key_values == "static-cache" + assert attention_bias.shape == (1, 1, 35, 35) + self.index += 1 + self.graph_steps += 1 + return torch.tensor([[[float(self.tokens[self.index])]]]), past_key_values + + def 
_require_eos_token_id(self): + return 2 + + def _action_token_id_to_bin(self): + return {100: 0, 101: 1, 102: 2, 103: 3} + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + action_mode="discrete", + inference_action_mode="discrete", + model_dtype="float32", + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))}, + discrete_generation_max_steps=None, + discrete_action_tokenizer="unused", + chunk_size=2, + n_action_steps=1, + rtc_config=None, + enable_inference_cuda_graph=True, + ) + policy._checkpoint_action_mode = None + policy.model = DummyModel() + policy.action_tokenizer = _DummyActionTokenizer() + torch.nn.Module.train(policy, False) + + actions = policy.predict_action_chunk( + { + "input_ids": torch.ones(1, 3, dtype=torch.long), + "attention_mask": torch.ones(1, 3, dtype=torch.long), + } + ) + + assert policy.model.used_static_cache + assert policy.model.graph_steps == 4 + assert actions.shape == (1, 1, 2) + assert torch.equal(actions, torch.ones(1, 1, 2)) + + +class _DummyMolmoBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed = torch.nn.Embedding(5, 3) + + def get_input_embeddings(self): + return self.embed + + +class _DummyMolmoModel(torch.nn.Module): + def __init__(self, *, tie_lm_head: bool = False): + super().__init__() + self.model = _DummyMolmoBackbone() + self.lm_head = torch.nn.Linear(3, 5, bias=False) + if tie_lm_head: + self.lm_head.weight = self.model.embed.weight + + def get_input_embeddings(self): + return self.model.embed + + +def test_freeze_embedding_freezes_input_embeddings_only_when_untied(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.model = _DummyMolmoModel() + + policy._freeze_input_embeddings() + + assert not policy.model.model.embed.weight.requires_grad + assert policy.model.lm_head.weight.requires_grad + + +def test_freeze_embedding_rejects_tied_lm_head_without_mutating(): 
+ policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.model = _DummyMolmoModel(tie_lm_head=True) + + with pytest.raises(RuntimeError, match="would also freeze lm_head"): + policy._freeze_input_embeddings() + + assert policy.model.model.embed.weight.requires_grad diff --git a/uv.lock b/uv.lock index 3eb1dda23..eebbb7f95 100644 --- a/uv.lock +++ b/uv.lock @@ -2915,6 +2915,11 @@ metaworld = [ { name = "scipy" }, { name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] +molmoact2 = [ + { name = "peft" }, + { name = "scipy" }, + { name = "transformers" }, +] motorbridge-dep = [ { name = "motorbridge" }, ] @@ -3131,6 +3136,7 @@ requires-dist = [ { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" }, { name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["molmoact2"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["motorbridge-dep"], marker = "extra == 'rebot'" }, { name = "lerobot", extras = ["motorbridge-smart-servo-dep"], marker = "extra == 'rebot'" }, { name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" }, @@ -3138,6 +3144,7 @@ requires-dist = [ { name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["peft"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" }, + { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'molmoact2'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" }, { name = "lerobot", extras = 
["peft-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["phone"], marker = "extra == 'all'" }, @@ -3165,6 +3172,7 @@ requires-dist = [ { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'aloha'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'libero'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'metaworld'" }, + { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'molmoact2'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'phone'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'pi'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" }, @@ -3176,6 +3184,7 @@ requires-dist = [ { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" }, + { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" }, @@ -3249,7 +3258,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", 
"lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt"