diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 412386e2d..8acaa4030 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,6 +59,8 @@ title: π₀-FAST (Pi0Fast) - local: pi05 title: π₀.₅ (Pi05) + - local: molmoact2 + title: MolmoAct2 - local: eo1 title: EO-1 - local: groot diff --git a/docs/source/molmoact2.mdx b/docs/source/molmoact2.mdx new file mode 100644 index 000000000..ddd178acd --- /dev/null +++ b/docs/source/molmoact2.mdx @@ -0,0 +1,433 @@ +# MolmoAct2 Policy + +MolmoAct2 is the LeRobot policy implementation of +[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot +training, evaluation, checkpointing, and dataset interfaces for easier use with +LeRobot datasets. + +This implementation currently supports training and evaluation for the regular +MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is +not included in this LeRobot policy yet and is coming soon. + +For the original MolmoAct2 training code used for the experiments reported in +the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2). + +## Installation Requirements + +Install LeRobot with the MolmoAct2 optional dependencies: + +```bash +pip install -e ".[molmoact2]" +``` + +To run the models in this repository, you need an NVIDIA GPU. The measurements +below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7 +padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use +`gradient_checkpointing=true` and include the forward pass, backward pass, +gradient clipping, optimizer step, and optimizer state allocation. Values are +peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for +dataloader workers, CUDA context, and fragmentation. + +Multi-GPU training through `accelerate` increases throughput and global batch +size, but this LeRobot port does not currently expose the original MolmoAct2 +`fsdp_devices` model-parallel training path. The current training script has +not been tested for multi-node training. + +| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 | +| ------------------------------------------------ | ----------------: | -----------------: | -----------------: | +| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - | +| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB | +| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB | +| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB | + +The repo has been tested with Ubuntu 22.04. + +## Usage + +To use MolmoAct2 in a LeRobot training config, set: + +```python +policy.type=molmoact2 +``` + +## Training + +MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face +checkpoint format or from a checkpoint already saved by LeRobot. Both routes use +the same LeRobot training loop, dataset transforms, checkpoint saving, and +logging. The difference is only how the initial policy weights and processor +state are loaded. + +### Training With Original MolmoAct2 Weight + +Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint, +for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load +the original HF model files, then build its own policy processor from the +dataset metadata and the policy options below. + +The command below shows full fine-tuning on the merged LIBERO dataset. It uses +bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image +augmentation, and LeRobot's checkpointing/logging path. + +```bash +accelerate launch \ + --num_processes=8 \ + --mixed_precision=bf16 \ + -m lerobot.scripts.lerobot_train \ + --dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.video_backend=pyav \ + --dataset.image_transforms.enable=true \ + --policy.type=molmoact2 \ + --policy.checkpoint_path=allenai/MolmoAct2-LIBERO \ + --policy.device=cuda \ + --policy.action_mode=both \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.setup_type="single franka robotic arm in libero" \ + --policy.control_mode="delta end-effector pose" \ + --policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \ + --policy.model_dtype=bfloat16 \ + --policy.num_flow_timesteps=8 \ + --policy.gradient_checkpointing=true \ + --policy.freeze_embedding=true \ + --policy.normalize_gripper=false \ + --policy.enable_knowledge_insulation=false \ + --policy.push_to_hub=false \ + --wandb.enable=true \ + --wandb.entity= \ + --wandb.project= \ + --job_name= \ + --output_dir=outputs/ \ + --steps=10000 \ + --batch_size=32 \ + --num_workers=4 \ + --log_freq=20 \ + --eval_freq=-1 \ + --save_checkpoint=true \ + --save_freq=2000 +``` + +### Training With LeRobot MolmoAct2 Weight + +Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by +LeRobot, either from a local `pretrained_model` directory or from the Hub. This +restores the saved LeRobot policy config, model weights, processor, and +normalization statistics. You can still override training-time options such as +`batch_size`, `steps`, LoRA flags, or `policy.action_mode`. + +```bash +accelerate launch \ + --num_processes=8 \ + --mixed_precision=bf16 \ + -m lerobot.scripts.lerobot_train \ + --dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \ + --dataset.video_backend=pyav \ + --dataset.image_transforms.enable=true \ + --policy.path=/path/to/pretrained_model \ + --policy.device=cuda \ + --policy.action_mode=both \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.model_dtype=bfloat16 \ + --policy.num_flow_timesteps=8 \ + --policy.gradient_checkpointing=true \ + --wandb.enable=true \ + --wandb.entity= \ + --wandb.project= \ + --job_name= \ + --output_dir=outputs/ \ + --steps=10000 \ + --batch_size=32 \ + --num_workers=4 \ + --log_freq=20 \ + --eval_freq=-1 \ + --save_checkpoint=true \ + --save_freq=2000 +``` + +### Common Practices + +For fine-tuning on a comparatively small dataset, such as a single LIBERO suite +or a real-world dataset with less than 200 demonstrations, a global batch size of +16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both +cases, we intentionally keep the action expert fully trainable, which we found +to be crucial for model performance. For larger fine-tuning datasets, larger +global batch sizes and full fine-tuning are usually preferred. + +### Common Policy Options + +- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from. + Use this for released MolmoAct2 weights. +- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints + created by LeRobot training. +- `policy.action_mode`: training target, one of `continuous`, `discrete`, or + `both`. `both` trains the flow-matching action expert and the discrete + action-token loss. +- `policy.train_action_expert_only`: trains only parameters whose names contain + `action_expert`. It requires `policy.action_mode=continuous`. +- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use + `policy.enable_lora_action_expert=true` only if LoRA should also cover action + expert linear layers. When `policy.enable_lora_action_expert=false`, the + action expert base weights remain fully trainable while the VLM is trained + through LoRA adapters. When `policy.enable_lora_action_expert=true`, the + action expert is also adapter-tuned instead of fully fine-tuned. +- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert + context K/V states before the action loss. The default is `false`. +- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use + `10`. This LeRobot port overrides the loaded checkpoint's + `max_action_horizon` with this value. +- `policy.n_action_steps`: number of actions consumed from each predicted + chunk before querying the policy again. For LIBERO, set it to `chunk_size`. +- `policy.setup_type`: text inserted into the prompt to describe the robot and + scene, e.g. `single franka robotic arm in libero`. More examples are listed + in the `metadata_by_tag` entries of + [`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json). +- `policy.control_mode`: text inserted into the prompt to describe the action + space, e.g. `delta end-effector pose` or `absolute joint pose`. +- `policy.image_keys`: ordered LeRobot image observation keys passed to the + processor. +- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`, + `bfloat16`, or `float16`. Use `bfloat16` for normal training. +- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per + example during training. We use `8` for fine-tuning. +- `policy.num_inference_steps`: optional override for continuous action + generation steps at inference time. +- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path + to reduce activation memory. +- `policy.freeze_embedding`: freezes input embeddings. The default is `true`. +- `policy.normalize_gripper`: controls whether gripper dimensions are included + in state/action quantile normalization. The default is `false`. +- `policy.normalize_language`: normalizes task strings before prompt + construction. The default is `true`. +- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss. + Released checkpoints use `policy.expected_max_action_dim=32`. +- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to + infer it from images, state dimension, action dimension, action horizon, and + discrete-action mode. + +### Learning Rates + +MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2 +fine-tuning experiments. + +- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM, + `policy.optimizer_vit_lr=5e-6` for the vision tower, + `policy.optimizer_connector_lr=5e-6` for image connector layers, and + `policy.optimizer_action_expert_lr=5e-5` for the action expert. +- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter + groups to `5e-5` when `policy.enable_lora_vlm=true`. By default, + `policy.enable_lora_action_expert=false`, so the action expert is still fully + fine-tuned with `policy.optimizer_action_expert_lr`. If + `policy.enable_lora_action_expert=true`, the action expert is trained through + LoRA adapters instead. +- Action-expert-only fine-tuning trains only the action expert and uses + `policy.optimizer_action_expert_lr=5e-5`. + +You can override the full fine-tuning and action-expert learning rates with +`policy.optimizer_lr`, `policy.optimizer_vit_lr`, +`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`. +Scheduler settings can be changed with `policy.scheduler_warmup_steps`, +`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`. + +### Dataset Quantile Statistics + +MolmoAct2 defaults to quantile normalization for state and action features. If +your dataset has not been converted with quantile statistics, you can add them +with: + +```bash +python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ + --repo-id=your_dataset +``` + +Alternatively, train MolmoAct2 with mean/std normalization: + +```bash +--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' +``` + +## Evaluation + +Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2 +HF checkpoints. For LIBERO replication, keep the EGL rendering environment +fixed and use `policy.per_episode_seed=true`. + +**Important:** We found that `num_steps_wait=10` does not reliably let the +LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation +results reported here use `num_steps_wait=50`. + +### Evaluation With LeRobot MolmoAct2 Weight + +Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and +normalization statistics are restored together with the model. + +```bash +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +lerobot-eval \ + --policy.path=allenai/MolmoAct2-LIBERO-LeRobot \ + --policy.inference_action_mode=continuous \ + --policy.model_dtype=bfloat16 \ + --policy.use_amp=true \ + --policy.enable_inference_cuda_graph=true \ + --policy.device=cuda \ + --policy.per_episode_seed=true \ + --policy.eval_seed=1000 \ + --env.type=libero \ + --env.task=libero_10,libero_goal,libero_object,libero_spatial \ + --env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \ + --eval.batch_size=1 \ + --eval.n_episodes=50 \ + --seed=1000 +``` + +### Evaluation With Original MolmoAct2 Weight + +You can evaluate a released Hugging Face checkpoint directly without first +converting it to a LeRobot checkpoint. In this case, set +`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`. +For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state +normalization statistics, action horizon, prompt metadata, and image-key order +from the checkpoint's `norm_stats.json`. + +To fully replicate the MolmoAct2 paper results with released Hugging Face +checkpoints, we recommend using the v0.5.1-pinned +[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference) +branch. That branch matches the original evaluation settings used for the +reported numbers. + +```bash +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +lerobot-eval \ + --policy.type=molmoact2 \ + --policy.checkpoint_path=allenai/MolmoAct2-LIBERO \ + --policy.norm_tag=libero \ + --policy.inference_action_mode=continuous \ + --policy.model_dtype=float32 \ + --policy.use_amp=false \ + --policy.enable_inference_cuda_graph=true \ + --policy.device=cuda \ + --policy.per_episode_seed=true \ + --policy.eval_seed=1000 \ + --env.type=libero \ + --env.task=libero_goal \ + --env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \ + --eval.batch_size=1 \ + --eval.n_episodes=50 \ + --seed=1000 +``` + +Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the +full LIBERO suite. The same command works for other released MolmoAct2 +checkpoints as long as the requested `policy.norm_tag` exists in that +checkpoint's `norm_stats.json`. + +### Common Evaluation Options + +- `policy.inference_action_mode`: required for rollout. Use `continuous` for + flow-matching inference or `discrete` for action-token inference. It must be + compatible with the training-time `policy.action_mode` saved in the + checkpoint. +- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints + saved by LeRobot. +- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo. + Use this with `policy.type=molmoact2` and `policy.norm_tag`. +- `policy.norm_tag`: selects normalization statistics, prompt metadata, + image-key order, and action horizon from the original checkpoint's + `norm_stats.json`. It is required for direct original-HF checkpoint + evaluation. +- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal + GPU evaluation. Use `float32` only when you explicitly want fp32 inference. +- `policy.use_amp`: runs the policy forward under autocast during eval. For + `model_dtype=bfloat16`, keep this enabled. +- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA + graph path for faster repeated continuous-action rollout. +- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous + action generation deterministic per episode for replication. +- `env.task`: comma-separated LIBERO suites or a single suite. Use + `libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark. +- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected + by the policy processor. + +## Performance Results + +### LIBERO Benchmark Results + +MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To +compare and test its LeRobot implementation, we fine-tuned +[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO) +for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on +8 H100 GPUs, then compared the results to the original MolmoAct2 reference +results. + +The LeRobot fine-tuned checkpoint reported here is available at +[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot) +and was trained on +[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset). + +| Benchmark | LeRobot Implementation | MolmoAct2 Original | +| -------------- | ---------------------: | -----------------: | +| LIBERO Spatial | 98.4% | 97.8% | +| LIBERO Object | 100.0% | 100.0% | +| LIBERO Goal | 98.0% | 97.8% | +| LIBERO 10 | 96.6% | 93.2% | +| Average | 98.25% | 97.20% | + +These results demonstrate MolmoAct2's strong performance across diverse robotic +manipulation tasks. To reproduce them, follow the instructions in the LIBERO +evaluation section. + +## Differences From the Original Implementation + +This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's +dataset, training, evaluation, checkpoint, and logging infrastructure. The main +differences from the original training repository are: + +- The original paper training stack loads the model in fp32 and trains under + mixed precision. This LeRobot port usually loads the checkpoint directly in + `policy.model_dtype=bfloat16` for lower memory use. +- The original repository uses its own FSDP/model-parallel training path. The + LeRobot port uses the standard LeRobot/Accelerate training path and has not + been tested for multi-node training. +- The original repository supports sequence packing. The LeRobot port trains on + one LeRobot sample per item and pads to an inferred fixed sequence budget. +- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving, + dataset transforms, image augmentation, and Weights & Biases logging + conventions. +- The original training path supports mixed action horizons by padding to + `max_action_horizon` and masking padded horizon slots in the action expert + self-attention. This is useful when training across datasets with different + control frequencies. The LeRobot port currently targets single-dataset + fine-tuning, so `policy.chunk_size` overrides the checkpoint + `max_action_horizon` and horizon masking is not implemented yet. Support for + this mixed-horizon path is planned. + +## Citation + +```bibtex +@misc{fang2026molmoact2actionreasoningmodels, + title={MolmoAct2: Action Reasoning Models for Real-world Deployment}, + author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna}, + year={2026}, + eprint={2605.02881}, + archivePrefix={arXiv}, + primaryClass={cs.RO}, + url={https://arxiv.org/abs/2605.02881}, +} +``` + +## License + +This model is licensed under Apache 2.0. It is intended for research and +educational use in accordance with +[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use), +consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2). diff --git a/docs/source/policy_molmoact2_README.md b/docs/source/policy_molmoact2_README.md new file mode 100644 index 000000000..df3a6341e --- /dev/null +++ b/docs/source/policy_molmoact2_README.md @@ -0,0 +1,39 @@ +# MolmoAct2 + +This repository contains the LeRobot policy implementation of +[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for +training, evaluation, checkpointing, and dataset compatibility. + +This implementation currently supports training and evaluation for the regular +MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is +not included in this LeRobot policy yet and is coming soon. + +For the original MolmoAct2 training code used for the experiments reported in +the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2). + +## LIBERO Evaluation + +Important: we found that `num_steps_wait=10` does not reliably let the LIBERO +scene stabilize and can degrade measured success. All LIBERO evaluation results +reported for this LeRobot implementation use `num_steps_wait=50`. + +## Citation + +```bibtex +@misc{fang2026molmoact2actionreasoningmodels, + title={MolmoAct2: Action Reasoning Models for Real-world Deployment}, + author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna}, + year={2026}, + eprint={2605.02881}, + archivePrefix={arXiv}, + primaryClass={cs.RO}, + url={https://arxiv.org/abs/2605.02881}, +} +``` + +## License + +This model is licensed under Apache 2.0. It is intended for research and +educational use in accordance with +[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use), +consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2). diff --git a/pyproject.toml b/pyproject.toml index 5d182648c..bd76586ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,6 +198,7 @@ wallx = [ "lerobot[qwen-vl-utils-dep]", ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] +molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"] multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"] groot = [ @@ -274,6 +275,7 @@ all = [ "lerobot[multi_task_dit]", "lerobot[wallx]", "lerobot[pi]", + "lerobot[molmoact2]", "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[xvla]", @@ -404,7 +406,8 @@ default.extend-ignore-identifiers-re = [ "thw", "inpt", "ROBOTIS", - "OT_VALUE" + "OT_VALUE", + "VanderBilt" ] # TODO: Uncomment when ready to use diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 3a6b8e5d2..68d23c9ca 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig from .groot.configuration_groot import GrootConfig as GrootConfig +from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig @@ -43,6 +44,7 @@ __all__ = [ "EO1Config", "GaussianActorConfig", "GrootConfig", + "MolmoAct2Config", "MultiTaskDiTConfig", "PI0Config", "PI0FastConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 8937bc6ae..64210cf01 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig from .groot.configuration_groot import GrootConfig +from .molmoact2.configuration_molmoact2 import MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config from .pi05.configuration_pi05 import PI05Config @@ -88,7 +89,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x". + "multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x", + "molmoact2". Returns: The policy class corresponding to the given name. @@ -151,6 +153,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .eo1.modeling_eo1 import EO1Policy return EO1Policy + elif name == "molmoact2": + from .molmoact2.modeling_molmoact2 import MolmoAct2Policy + + return MolmoAct2Policy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -168,7 +174,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor", - "smolvla", "wall_x". + "smolvla", "wall_x", "molmoact2". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -203,6 +209,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return WallXConfig(**kwargs) elif policy_type == "eo1": return EO1Config(**kwargs) + elif policy_type == "molmoact2": + return MolmoAct2Config(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -231,6 +239,7 @@ class ProcessorConfigKwargs(TypedDict, total=False): preprocessor_overrides: dict[str, Any] | None postprocessor_overrides: dict[str, Any] | None dataset_stats: dict[str, dict[str, torch.Tensor]] | None + dataset_meta: Any | None def make_pre_post_processors( @@ -285,6 +294,25 @@ def make_pre_post_processors( kwargs["preprocessor_overrides"] = preprocessor_overrides kwargs["postprocessor_overrides"] = postprocessor_overrides + if isinstance(policy_cfg, MolmoAct2Config): + from .molmoact2 import processor_molmoact2 # noqa: F401 + + preprocessor_overrides = dict(kwargs.get("preprocessor_overrides", {})) + if "normalizer_processor" in preprocessor_overrides: + preprocessor_overrides.setdefault( + "molmoact2_masked_normalizer", + preprocessor_overrides.pop("normalizer_processor"), + ) + kwargs["preprocessor_overrides"] = preprocessor_overrides + + postprocessor_overrides = dict(kwargs.get("postprocessor_overrides", {})) + if "unnormalizer_processor" in postprocessor_overrides: + postprocessor_overrides.setdefault( + "molmoact2_masked_unnormalizer", + postprocessor_overrides.pop("unnormalizer_processor"), + ) + kwargs["postprocessor_overrides"] = postprocessor_overrides + preprocessor = PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, config_filename=kwargs.get( @@ -414,6 +442,15 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, MolmoAct2Config): + from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors + + processors = make_molmoact2_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), + ) + else: try: processors = _make_processors_from_policy_config( @@ -499,6 +536,10 @@ def make_policy( action_names = ds_meta.features.get(ACTION, {}).get("names") if action_names is not None: cfg.action_feature_names = list(action_names) + if ds_meta is not None: + set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None) + if callable(set_dataset_feature_metadata): + set_dataset_feature_metadata(ds_meta.features) kwargs["config"] = cfg diff --git a/src/lerobot/policies/molmoact2/README.md b/src/lerobot/policies/molmoact2/README.md new file mode 120000 index 000000000..ef419516d --- /dev/null +++ b/src/lerobot/policies/molmoact2/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_molmoact2_README.md \ No newline at end of file diff --git a/src/lerobot/policies/molmoact2/__init__.py b/src/lerobot/policies/molmoact2/__init__.py new file mode 100644 index 000000000..d1f02fe8a --- /dev/null +++ b/src/lerobot/policies/molmoact2/__init__.py @@ -0,0 +1,15 @@ +from .configuration_molmoact2 import MolmoAct2Config + +__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"] + + +def __getattr__(name): + if name == "MolmoAct2Policy": + from .modeling_molmoact2 import MolmoAct2Policy + + return MolmoAct2Policy + if name == "make_molmoact2_pre_post_processors": + from .processor_molmoact2 import make_molmoact2_pre_post_processors + + return make_molmoact2_pre_post_processors + raise AttributeError(name) diff --git a/src/lerobot/policies/molmoact2/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/configuration_molmoact2.py new file mode 100644 index 000000000..da21774a6 --- /dev/null +++ b/src/lerobot/policies/molmoact2/configuration_molmoact2.py @@ -0,0 +1,324 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Any + +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import ( + AdamWConfig, + CosineDecayWithWarmupSchedulerConfig, + LRSchedulerConfig, + OptimizerConfig, +) +from lerobot.utils.constants import ACTION, OBS_STATE + +from ..rtc.configuration_rtc import RTCConfig + +MOLMOACT2_DEFAULT_NUM_IMAGES = 2 +MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196 +MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80 +MOLMOACT2_TASK_TOKEN_BUDGET = 32 +MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32 +MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64 +MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4 +MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6 +MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95 + + +@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup") +@dataclass +class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig): + """MolmoAct2-local cosine scheduler with optional decay-step auto-match. + + LeRobot's generic cosine scheduler keeps an explicit integer decay length. + For MolmoAct2, leaving num_decay_steps unset means "decay across this run's + training steps"; build() is the first point where num_training_steps is known. + """ + + num_decay_steps: int | None + + def build(self, optimizer, num_training_steps: int): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.peak_lr, + decay_lr=self.decay_lr, + num_warmup_steps=self.num_warmup_steps, + num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps, + ).build(optimizer, num_training_steps=num_training_steps) + + +def _round_up(value: int, multiple: int) -> int: + return int(math.ceil(value / multiple) * multiple) + + +def infer_molmoact2_max_sequence_length( + *, + num_images: int, + state_dim: int, + action_dim: int, + action_horizon: int, + include_discrete_action: bool, +) -> int: + """Infer the padded text/image sequence cap from MolmoAct2's fixed token layout.""" + if num_images < 1: + num_images = MOLMOACT2_DEFAULT_NUM_IMAGES + if state_dim < 0: + state_dim = 0 + if action_dim < 1: + action_dim = 1 + if action_horizon < 1: + action_horizon = 1 + + image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE + prompt_tokens = ( + MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET + + MOLMOACT2_TASK_TOKEN_BUDGET + + state_dim + + MOLMOACT2_SEQUENCE_LENGTH_MARGIN + ) + action_tokens = 0 + if include_discrete_action: + action_tokens_per_step = max( + MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP, + math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM), + ) + action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step + + return _round_up( + image_tokens + prompt_tokens + action_tokens, + MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE, + ) + + +@PreTrainedConfig.register_subclass("molmoact2") +@dataclass +class MolmoAct2Config(PreTrainedConfig): + """MolmoAct2 policy backed by the converted HF checkpoint implementation.""" + + checkpoint_path: str = "allenai/MolmoAct2" + checkpoint_revision: str | None = None + checkpoint_force_download: bool = False + trust_remote_code: bool = True + + n_obs_steps: int = 1 + chunk_size: int = 30 + n_action_steps: int = 30 + + action_mode: str = "both" + inference_action_mode: str | None = None + discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer" + discrete_generation_max_steps: int | None = None + norm_tag: str | None = None + + setup_type: str = "" + control_mode: str = "" + image_keys: list[str] = field(default_factory=list) + normalize_language: bool = True + add_setup_tokens: bool = True + add_control_tokens: bool = True + normalize_gripper: bool = False + num_state_tokens: int = 256 + # Leave unset for the default MolmoAct2 sequence budget inferred from the fixed + # image/prompt/state/action token layout. Override only for unusual long prompts. + max_sequence_length: int | None = None + + # Fixed by released MolmoAct2 checkpoints. We validate this at model load. + expected_max_action_dim: int = 32 + + # Flow-matching training knobs copied from the original MolmoAct2 training path. + num_flow_timesteps: int = 8 + flow_matching_cutoff: float = 1.0 + flow_matching_time_offset: float = 0.001 + flow_matching_time_scale: float = 0.999 + flow_matching_beta_alpha: float = 1.0 + flow_matching_beta_beta: float = 1.5 + num_inference_steps: int | None = None + mask_action_dim_padding: bool = True + enable_inference_cuda_graph: bool = True + # MolmoAct2-local eval option. When enabled, stochastic continuous action + # generation uses a rollout-local generator derived from eval_seed. + per_episode_seed: bool = False + eval_seed: int | None = None + rtc_config: RTCConfig | None = None + + # Default is full finetuning with gradients from the action expert flowing into the VLM. + enable_lora_vlm: bool = False + lora_rank: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_bias: str = "none" + enable_lora_action_expert: bool = False + enable_knowledge_insulation: bool = False + freeze_embedding: bool = True + train_action_expert_only: bool = False + gradient_checkpointing: bool = False + + model_dtype: str = "bfloat16" + softmax_auxiliary_loss: bool = True + softmax_auxiliary_loss_scale: float = 1e-4 + discrete_loss_token_weighting: str = "root_subsegments_root_tokens" + + optimizer_lr: float = 1e-5 + optimizer_vit_lr: float = 5e-6 + optimizer_connector_lr: float = 5e-6 + optimizer_action_expert_lr: float = 5e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-6 + optimizer_weight_decay: float = 0.0 + optimizer_grad_clip_norm: float = 1.0 + + scheduler_warmup_steps: int = 200 + scheduler_decay_steps: int | None = None + scheduler_decay_lr: float = 1e-6 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.QUANTILES, + "ACTION": NormalizationMode.QUANTILES, + } + ) + + input_features: dict[str, PolicyFeature] = field(default_factory=dict) + output_features: dict[str, PolicyFeature] = field(default_factory=dict) + dataset_feature_names: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + super().__post_init__() + if self.action_mode not in {"continuous", "discrete", "both"}: + raise ValueError( + f"Unsupported action_mode={self.action_mode!r}. " + "Expected one of {'continuous', 'discrete', 'both'}." + ) + if self.inference_action_mode not in {None, "continuous", "discrete"}: + raise ValueError( + f"Unsupported inference_action_mode={self.inference_action_mode!r}. " + "Expected one of {None, 'continuous', 'discrete'}." + ) + if self.inference_action_mode == "continuous" and self.action_mode == "discrete": + raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.") + if self.inference_action_mode == "discrete" and self.action_mode == "continuous": + raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.") + if self.train_action_expert_only and self.action_mode != "continuous": + raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.") + if self.train_action_expert_only and self.enable_lora_vlm: + raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.") + if self.enable_lora_action_expert and not self.enable_lora_vlm: + raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.") + if self.chunk_size < 1: + raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.") + if self.n_action_steps < 1: + raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.") + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})." + ) + if self.expected_max_action_dim != 32: + raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.") + if self.model_dtype not in {"float32", "bfloat16", "float16"}: + raise ValueError( + f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'." + ) + if self.lora_rank < 1: + raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.") + if self.lora_alpha < 1: + raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.") + if not 0 <= self.lora_dropout <= 1: + raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.") + if self.lora_bias not in {"none", "all", "lora_only"}: + raise ValueError( + f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'." + ) + if self.discrete_loss_token_weighting not in { + "none", + "token", + "root_tokens", + "root_subsegments", + "root_subsegments_root_tokens", + }: + raise ValueError( + f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}." + ) + if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1: + raise ValueError( + f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}." + ) + if self.max_sequence_length is not None and self.max_sequence_length < 1: + raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.") + + def inferred_max_sequence_length( + self, + *, + num_images: int | None = None, + state_dim: int | None = None, + action_dim: int | None = None, + action_horizon: int | None = None, + include_discrete_action: bool | None = None, + ) -> int: + if self.max_sequence_length is not None: + return int(self.max_sequence_length) + + if num_images is None: + num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES + if state_dim is None: + state_feature = self.robot_state_feature + state_dim = int(state_feature.shape[0]) if state_feature is not None else 0 + if action_dim is None: + action_feature = self.action_feature + action_dim = ( + int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim + ) + if action_horizon is None: + action_horizon = self.chunk_size + if include_discrete_action is None: + include_discrete_action = self.action_mode in {"discrete", "both"} + + return infer_molmoact2_max_sequence_length( + num_images=int(num_images), + state_dim=int(state_dim), + action_dim=int(action_dim), + action_horizon=int(action_horizon), + include_discrete_action=bool(include_discrete_action), + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list[int]: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None + + def get_optimizer_preset(self) -> OptimizerConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + return MolmoAct2CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None: + self.dataset_feature_names = {} + for key in (ACTION, OBS_STATE): + feature = features.get(key) if isinstance(features, dict) else None + if isinstance(feature, dict) and feature.get("names") is not None: + self.dataset_feature_names[key] = feature["names"] + + def validate_features(self) -> None: + if OBS_STATE not in self.input_features: + self.input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(0,)) + if ACTION not in self.output_features: + self.output_features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(0,)) diff --git a/src/lerobot/policies/molmoact2/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/modeling_molmoact2.py new file mode 100644 index 000000000..97bfef572 --- /dev/null +++ b/src/lerobot/policies/molmoact2/modeling_molmoact2.py @@ -0,0 +1,2108 @@ +from __future__ import annotations + +import json +import os +import types +from collections import defaultdict, deque +from contextlib import nullcontext, suppress +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torch.utils.checkpoint +from huggingface_hub import snapshot_download +from torch import Tensor +from torch.distributions import Beta + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION +from lerobot.utils.import_utils import require_package + +from ..rtc.modeling_rtc import RTCProcessor +from .configuration_molmoact2 import MolmoAct2Config + +_MODEL_INPUT_KEYS = { + "input_ids", + "pixel_values", + "image_token_pooling", + "image_grids", + "image_num_crops", + "pixel_values_videos", + "video_token_pooling", + "video_grids", + "attention_mask", + "position_ids", + "past_key_values", + "token_type_ids", + "inputs_embeds", +} + + +def _hf_token() -> str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + +def _resolve_checkpoint_location( + checkpoint_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + checkpoint_path = str(checkpoint_path or "").strip() + if not checkpoint_path: + raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") + local_path = Path(checkpoint_path).expanduser() + if local_path.exists(): + return str(local_path) + return snapshot_download( + repo_id=checkpoint_path, + repo_type="model", + revision=revision, + force_download=force_download, + token=_hf_token(), + ) + + +def _load_hf_norm_metadata_for_tag( + checkpoint_path: str, + *, + revision: str | None, + force_download: bool, + norm_tag: str | None, +) -> dict[str, Any]: + norm_tag = str(norm_tag or "").strip() + if not norm_tag: + return {} + checkpoint_location = Path( + _resolve_checkpoint_location( + checkpoint_path, + revision=revision, + force_download=force_download, + ) + ) + norm_stats_filename = "norm_stats.json" + config_path = checkpoint_location / "config.json" + if config_path.exists(): + with suppress(OSError, json.JSONDecodeError): + norm_stats_filename = str( + json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename + ) + stats_path = checkpoint_location / norm_stats_filename + if not stats_path.exists(): + raise FileNotFoundError( + f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." + ) + payload = json.loads(stats_path.read_text()) + metadata_by_tag = payload.get("metadata_by_tag") + if not isinstance(metadata_by_tag, dict): + raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") + metadata = metadata_by_tag.get(norm_tag) + if not isinstance(metadata, dict): + available = sorted(str(tag) for tag in metadata_by_tag) + raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") + return metadata + + +def _torch_dtype(dtype: str) -> torch.dtype: + if dtype == "float32": + return torch.float32 + if dtype == "bfloat16": + return torch.bfloat16 + if dtype == "float16": + return torch.float16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +def _sample_beta_timesteps( + *, + batch_size: int, + device: torch.device, + cutoff: float, + time_offset: float, + time_scale: float, + alpha: float, + beta: float, +) -> Tensor: + if cutoff < time_offset: + raise ValueError(f"flow-matching cutoff must be >= time_offset, got {cutoff} < {time_offset}") + if time_scale <= 0: + raise ValueError(f"flow-matching time_scale must be > 0, got {time_scale}") + upper = min(cutoff, time_offset + time_scale) + dist = Beta(torch.tensor(alpha, device=device), torch.tensor(beta, device=device)) + samples = dist.sample((batch_size,)) + scale = upper - time_offset + if scale == 0: + return torch.full((batch_size,), time_offset, device=device, dtype=samples.dtype) + return time_offset + scale * samples + + +def _patch_batched_image_attention_bias(backbone: Any) -> None: + original = getattr(backbone, "_build_native_attention_bias", None) + if original is None: + return + original_func = getattr(original, "__func__", original) + original_globals = getattr(original_func, "__globals__", {}) + cache_seq_len = original_globals.get("_cache_seq_len_int") + cache_max_len = original_globals.get("_cache_max_len_int") + if cache_seq_len is None or cache_max_len is None: + return + + def _build_native_attention_bias( + self, + *, + inputs_embeds: Tensor, + attention_mask: Tensor | None, + token_type_ids: Tensor | None, + past_key_values: Any, + ) -> Tensor: + if attention_mask is not None and attention_mask.ndim == 4: + return attention_mask.to(device=inputs_embeds.device) + batch_size, seq_len = inputs_embeds.shape[:2] + past_length = int(cache_seq_len(past_key_values)) + current_length = past_length + int(seq_len) + max_cache_len = int(cache_max_len(past_key_values)) + attention_mask_len = max_cache_len if max_cache_len > 0 else current_length + device = inputs_embeds.device + + if attention_mask is None: + positions = torch.arange(attention_mask_len, device=device) + valid_mask = positions.unsqueeze(0) < current_length + valid_mask = valid_mask.expand(batch_size, -1) + elif attention_mask.ndim == 2: + valid_mask = torch.zeros((batch_size, attention_mask_len), device=device, dtype=torch.bool) + source_mask = attention_mask.to(device=device, dtype=torch.bool) + copy_len = min(int(source_mask.shape[-1]), attention_mask_len) + if copy_len > 0: + valid_mask[:, :copy_len] = source_mask[:, :copy_len] + if attention_mask_len > current_length: + valid_mask[:, current_length:] = False + else: + raise ValueError(f"Unsupported attention_mask shape for MolmoAct2: {tuple(attention_mask.shape)}") + + valid_mask = valid_mask[:, None, None, :] + causal_mask = torch.tril( + torch.ones(attention_mask_len, attention_mask_len, device=device, dtype=torch.bool) + )[None, None, past_length:current_length, :attention_mask_len] + + if token_type_ids is not None and past_length == 0: + causal_mask = causal_mask.expand(batch_size, -1, -1, -1).clone() + image_mask = token_type_ids.to(device=device, dtype=torch.bool) + can_attend_back = image_mask[:, :, None] & image_mask[:, None, :] + image_len = min(int(token_type_ids.shape[1]), attention_mask_len) + causal_mask[:, :, :, :image_len] = ( + causal_mask[:, :, :, :image_len] | can_attend_back[:, None, :, :image_len] + ) + + allowed = valid_mask & causal_mask + return torch.where( + allowed, + torch.zeros((), device=device, dtype=inputs_embeds.dtype), + torch.full((), torch.finfo(inputs_embeds.dtype).min, device=device, dtype=inputs_embeds.dtype), + ) + + backbone._build_native_attention_bias = types.MethodType(_build_native_attention_bias, backbone) + + +def _patch_leaf_safe_input_embedding_update(backbone: Any) -> None: + if getattr(backbone, "_lerobot_leaf_safe_input_embedding_update_patched", False): + return + if not callable(getattr(backbone, "build_input_embeddings", None)): + return + + def _build_input_embeddings( + self, + input_ids: Tensor, + images: Tensor | None = None, + token_pooling: Tensor | None = None, + ) -> tuple[Tensor, Tensor | None]: + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + x = self.transformer.wte(input_ids) + + image_features = None + if images is not None: + image_features = self.vision_backbone(images, token_pooling).to(x.device) + is_image_patch = input_ids.reshape(-1) == self.config.image_patch_id + if is_image_patch.sum() != len(image_features): + raise RuntimeError( + f"Expected {int(is_image_patch.sum())} image patch embeddings, got {len(image_features)}." + ) + flat_x = x.reshape(-1, x.shape[-1]).clone() + flat_x[is_image_patch] = flat_x[is_image_patch] + image_features + x = flat_x.reshape_as(x) + + x = self.transformer.emb_drop(x) + return x, image_features + + backbone.build_input_embeddings = types.MethodType(_build_input_embeddings, backbone) + backbone._lerobot_leaf_safe_input_embedding_update_patched = True + + +def _patch_memory_efficient_vision_backbone(backbone: Any, *, gradient_checkpointing: bool) -> None: + vision_backbone = getattr(backbone, "vision_backbone", None) + if vision_backbone is None or getattr( + vision_backbone, "_lerobot_memory_efficient_vision_backbone_patched", False + ): + return + + image_vit = getattr(vision_backbone, "image_vit", None) + transformer = getattr(image_vit, "transformer", None) + resblocks = getattr(transformer, "resblocks", None) + if image_vit is None or transformer is None or resblocks is None: + return + if not hasattr(vision_backbone, "vit_layers"): + return + + def _encode_image(self, images: Tensor) -> Tensor: + batch_size, num_crops, num_patches, patch_dim = images.shape + images = images.view(batch_size * num_crops, num_patches, patch_dim) + + x = self.image_vit.patch_embedding(images) + x = self.image_vit.add_pos_emb(x, self.image_vit.config.image_num_patch) + + needed_layers = {int(layer) for layer in self.vit_layers} + selected_features: dict[int, Tensor] = {} + use_checkpoint = bool( + self._lerobot_vision_gradient_checkpointing and self.training and torch.is_grad_enabled() + ) + for layer_idx, block in enumerate(self.image_vit.transformer.resblocks): + if use_checkpoint: + x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) + else: + x = block(x) + if layer_idx in needed_layers: + selected_features[layer_idx] = x + + if len(selected_features) != len(needed_layers): + missing = sorted(needed_layers - set(selected_features)) + raise RuntimeError(f"MolmoAct2 vision backbone did not produce requested layers: {missing}.") + + image_features = torch.cat([selected_features[int(layer)] for layer in self.vit_layers], dim=-1) + if self.num_prefix_tokens > 0: + image_features = image_features[:, 1:] + image_features = image_features.view(batch_size, num_crops, num_patches, -1) + return image_features + + vision_backbone.encode_image = types.MethodType(_encode_image, vision_backbone) + vision_backbone._lerobot_vision_gradient_checkpointing = bool(gradient_checkpointing) + vision_backbone._lerobot_memory_efficient_vision_backbone_patched = True + + +def _patch_training_kv_collection(backbone: Any) -> None: + """Expose per-layer VLM KV tensors without enabling HF autoregressive cache.""" + if getattr(backbone, "_lerobot_training_kv_collection_patched", False): + return + + transformer = getattr(backbone, "transformer", None) + blocks = getattr(transformer, "blocks", None) + if transformer is None or blocks is None: + raise RuntimeError("MolmoAct2 checkpoint does not expose a patchable text transformer.") + + original_transformer_forward = transformer.forward + from transformers.masking_utils import create_causal_mask + from transformers.modeling_outputs import BaseModelOutputWithPast + + def _patch_attention(attention: torch.nn.Module) -> None: + if getattr(attention, "_lerobot_training_kv_collection_patched", False): + return + + original_attention_forward = attention.forward + original_attention_func = getattr(original_attention_forward, "__func__", original_attention_forward) + attention_globals = getattr(original_attention_func, "__globals__", {}) + apply_rotary_pos_emb = attention_globals.get("apply_rotary_pos_emb") + repeat_kv = attention_globals.get("repeat_kv") + eager_attention_forward = attention_globals.get("eager_attention_forward") + all_attention_functions = attention_globals.get("ALL_ATTENTION_FUNCTIONS") + if ( + apply_rotary_pos_emb is None + or repeat_kv is None + or eager_attention_forward is None + or all_attention_functions is None + ): + raise RuntimeError("MolmoAct2 attention internals changed; cannot patch KV collection.") + + def _attention_forward( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Tensor | None, + past_key_values: Any | None = None, + cache_position: Tensor | None = None, + **kwargs, + ): + collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) + if not collect_layer_kv_states: + return original_attention_forward( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.att_proj(hidden_states) + query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1) + value_states = value_states.view(hidden_shape) + + if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type != "qwen3": + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "qwen3": + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + collected_key_states = key_states + collected_value_states = value_states + dropout_p = 0.0 if not self.training else self.attention_dropout + if self.config._attn_implementation == "sdpa" and ( + attention_mask is None or torch.is_tensor(attention_mask) + ): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=dropout_p, + is_causal=attention_mask is None, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_weights = None + else: + attention_interface = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = all_attention_functions[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=dropout_p, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.attn_out(attn_output) + return attn_output, attn_weights, collected_key_states, collected_value_states + + attention.forward = types.MethodType(_attention_forward, attention) + attention._lerobot_training_kv_collection_patched = True + + def _patch_decoder_layer(layer: torch.nn.Module) -> None: + if getattr(layer, "_lerobot_training_kv_collection_patched", False): + return + + _patch_attention(layer.self_attn) + original_layer_forward = layer.forward + is_post_norm = "PostNorm" in layer.__class__.__name__ + + def _decoder_layer_forward( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + past_key_values: Any | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Tensor | None = None, + **kwargs, + ): + collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) + if not collect_layer_kv_states: + return original_layer_forward( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + residual = hidden_states + attn_input = hidden_states if is_post_norm else self.attn_norm(hidden_states) + attn_output, self_attn_weights, key_states, value_states = self.self_attn( + hidden_states=attn_input, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + collect_layer_kv_states=True, + **kwargs, + ) + if is_post_norm: + attn_output = self.attn_norm(attn_output) + hidden_states = residual + self.dropout(attn_output) + + residual = hidden_states + if is_post_norm: + hidden_states = self.mlp(hidden_states) + hidden_states = self.ff_norm(hidden_states) + else: + hidden_states = self.ff_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + return outputs + (key_states, value_states) + + layer.forward = types.MethodType(_decoder_layer_forward, layer) + layer._lerobot_training_kv_collection_patched = True + + for block in blocks: + _patch_decoder_layer(block) + + def _transformer_forward( + self, + input_ids: Tensor | None = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + past_key_values: Any | None = None, + inputs_embeds: Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: Tensor | None = None, + **kwargs, + ): + collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) + if not collect_layer_kv_states: + return original_transformer_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + if past_key_values is not None: + raise ValueError("collect_layer_kv_states only supports full-sequence training forwards.") + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = False + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + inputs_embeds = self.wte(input_ids) + + if cache_position is None: + cache_position = torch.arange( + 0, + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if torch.is_tensor(attention_mask) and attention_mask.ndim == 4: + causal_mask_mapping = attention_mask + elif not isinstance(causal_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": position_ids, + } + causal_mask_mapping = create_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + if self.config.rope_scaling_layers is not None: + position_embeddings_mapping = { + "default": self.rotary_embs["default"](hidden_states, position_ids), + "scaling": self.rotary_embs["scaling"](hidden_states, position_ids), + } + else: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + collected_kv_states = [] + + for layer_idx, decoder_block in enumerate(self.blocks[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.config.rope_scaling_layers is not None: + position_embeddings_i = ( + position_embeddings_mapping["scaling"] + if layer_idx in self.config.rope_scaling_layers + else position_embeddings_mapping["default"] + ) + else: + position_embeddings_i = position_embeddings + + layer_outputs = decoder_block( + hidden_states, + position_embeddings=position_embeddings_i, + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=None, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + collect_layer_kv_states=True, + **kwargs, + ) + hidden_states = layer_outputs[0] + + output_idx = 1 + if output_attentions: + all_self_attns += (layer_outputs[output_idx],) + output_idx += 1 + collected_kv_states.append((layer_outputs[output_idx], layer_outputs[output_idx + 1])) + + hidden_states = self.ln_f(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=tuple(collected_kv_states), + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + transformer.forward = types.MethodType(_transformer_forward, transformer) + backbone._lerobot_training_kv_collection_patched = True + + +class MolmoAct2Policy(PreTrainedPolicy): + config_class = MolmoAct2Config + name = "molmoact2" + + def __init__( + self, + config: MolmoAct2Config, + *inputs, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + dataset_meta: Any | None = None, + **kwargs, + ): + super().__init__(config, *inputs, **kwargs) + self._checkpoint_action_mode = self._load_saved_policy_action_mode() + self._apply_norm_tag_metadata() + self.config.validate_features() + del inputs, kwargs, dataset_stats, dataset_meta + self._action_queues: dict[int, deque[Tensor]] = defaultdict(deque) + self._rollout_action_generator: torch.Generator | None = None + self._rollout_task_key: tuple[Any, ...] | None = None + self._rollout_index_for_task = -1 + self.rtc_processor: RTCProcessor | None = None + self.action_tokenizer: Any | None = None + self._load_hf_model() + self._validate_inference_action_mode() + if self.config.enable_lora_vlm: + self._apply_lora_adapters() + self.init_rtc_processor() + + def _load_saved_policy_action_mode(self) -> str | None: + pretrained_path = getattr(self.config, "pretrained_path", None) + if pretrained_path is None: + return None + config_path = Path(pretrained_path) / "config.json" + if not config_path.exists(): + return None + try: + mode = json.loads(config_path.read_text()).get("action_mode") + except (OSError, json.JSONDecodeError): + return None + if mode in {"continuous", "discrete", "both"}: + return str(mode) + return None + + def _training_action_mode(self) -> str: + return getattr(self, "_checkpoint_action_mode", None) or self.config.action_mode + + def _validate_inference_action_mode(self) -> None: + requested_mode = self.config.inference_action_mode + if requested_mode is None: + return + training_mode = self._training_action_mode() + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError( + "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run " + "continuous inference." + ) + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError( + "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run " + "discrete inference. Train with action_mode='both' or action_mode='discrete' first." + ) + + def _apply_norm_tag_metadata(self) -> None: + if not str(self.config.norm_tag or "").strip(): + return + metadata = _load_hf_norm_metadata_for_tag( + self.config.checkpoint_path, + revision=self.config.checkpoint_revision, + force_download=bool(self.config.checkpoint_force_download), + norm_tag=self.config.norm_tag, + ) + if metadata.get("action_horizon") is not None: + self.config.chunk_size = int(metadata["action_horizon"]) + if metadata.get("n_action_steps") is not None: + self.config.n_action_steps = int(metadata["n_action_steps"]) + if not self.config.setup_type and metadata.get("setup_type") is not None: + self.config.setup_type = str(metadata["setup_type"]) + if not self.config.control_mode and metadata.get("control_mode") is not None: + self.config.control_mode = str(metadata["control_mode"]) + if not self.config.image_keys and isinstance(metadata.get("camera_keys"), list): + self.config.image_keys = [str(key) for key in metadata["camera_keys"]] + + def _load_hf_model(self) -> None: + require_package("transformers", extra="molmoact2") + from transformers import AutoModelForImageTextToText + + checkpoint_location = _resolve_checkpoint_location( + self.config.checkpoint_path, + revision=self.config.checkpoint_revision, + force_download=bool(self.config.checkpoint_force_download), + ) + model_dtype = _torch_dtype(self.config.model_dtype) + self.model = AutoModelForImageTextToText.from_pretrained( + checkpoint_location, + trust_remote_code=self.config.trust_remote_code, + dtype=model_dtype, + low_cpu_mem_usage=True, + token=_hf_token(), + ) + hf_max_action_dim = int(getattr(self.model.config, "max_action_dim", -1)) + if hf_max_action_dim != int(self.config.expected_max_action_dim): + raise ValueError( + "MolmoAct2 checkpoint max_action_dim mismatch: " + f"checkpoint={hf_max_action_dim}, expected={self.config.expected_max_action_dim}." + ) + if hf_max_action_dim != 32: + raise ValueError( + f"MolmoAct2 released checkpoints must have max_action_dim=32, got {hf_max_action_dim}." + ) + + if not hasattr(self.model.config, "max_action_horizon"): + raise ValueError("MolmoAct2 HF checkpoints must define `max_action_horizon`.") + self._override_loaded_max_action_horizon(int(self.config.chunk_size)) + + if not hasattr(self.model.config, "action_mode"): + raise ValueError( + "MolmoAct2 HF checkpoints must define `action_mode`. If this is a released " + "MolmoAct2 checkpoint, refresh the local Hub cache with " + "`policy.checkpoint_force_download=true` after the updated files are pushed." + ) + checkpoint_action_mode = str(self.model.config.action_mode) + if self.config.action_mode == "both" and checkpoint_action_mode != "both": + raise ValueError( + f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}." + ) + if self.config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}: + raise ValueError( + f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, " + f"got {checkpoint_action_mode!r}." + ) + if self.config.action_mode in {"continuous", "both"} and not bool( + getattr(self.model.config, "add_action_expert", False) + ): + raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.") + + if self.config.freeze_embedding: + self._freeze_input_embeddings() + if self.config.train_action_expert_only: + self._freeze_non_action_expert_parameters() + _patch_batched_image_attention_bias(self._backbone()) + _patch_leaf_safe_input_embedding_update(self._backbone()) + _patch_memory_efficient_vision_backbone( + self._backbone(), + gradient_checkpointing=bool(self.config.gradient_checkpointing), + ) + _patch_training_kv_collection(self._backbone()) + if self.config.gradient_checkpointing: + self._enable_gradient_checkpointing() + self.train(self.training) + + def reset(self) -> None: + self._action_queues = defaultdict(deque) + self._rollout_action_generator = None + + def _set_inference_cuda_graph_enabled(self, enabled: bool) -> None: + if not hasattr(self, "model"): + return + hf_model = self._hf_model() + enabled = bool(enabled and getattr(self.config, "enable_inference_cuda_graph", True)) + managers = [ + getattr(self._backbone(), "action_cuda_graph_manager", None), + getattr(hf_model, "action_cuda_graph_manager", None), + getattr(hf_model, "depth_decode_cuda_graph_manager", None), + ] + seen: set[int] = set() + for manager in managers: + if manager is None or id(manager) in seen: + continue + seen.add(id(manager)) + set_enabled = getattr(manager, "set_enabled", None) + if callable(set_enabled): + set_enabled(enabled) + + def init_rtc_processor(self) -> None: + self.rtc_processor = None + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _action_expert(self) -> torch.nn.Module: + return self._backbone()._require_action_expert() + + def _enable_gradient_checkpointing(self) -> None: + enable_gradient_checkpointing = getattr(self._hf_model(), "gradient_checkpointing_enable", None) + if callable(enable_gradient_checkpointing): + try: + enable_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False}) + except TypeError: + enable_gradient_checkpointing() + else: + transformer = getattr(self._backbone(), "transformer", None) + if transformer is None: + raise RuntimeError("gradient_checkpointing=true, but MolmoAct2 exposes no text transformer.") + transformer.gradient_checkpointing = True + + transformer = getattr(self._backbone(), "transformer", None) + if transformer is not None: + transformer.gradient_checkpointing = True + + def _freeze_non_action_expert_parameters(self) -> None: + trainable_params = 0 + for name, param in self.named_parameters(): + param.requires_grad = "action_expert" in name + if param.requires_grad: + trainable_params += param.numel() + if trainable_params == 0: + raise RuntimeError("train_action_expert_only=true, but no action_expert parameters were found.") + + def _unfreeze_action_expert_parameters(self) -> None: + trainable_params = 0 + for name, param in self.named_parameters(): + if "action_expert" in name: + param.requires_grad_(True) + trainable_params += param.numel() + if trainable_params == 0: + raise RuntimeError("enable_lora_vlm=true, but no action_expert parameters were found.") + + def train(self, mode: bool = True): + super().train(mode) + if getattr(self.config, "train_action_expert_only", False) and hasattr(self, "model"): + self._hf_model().eval() + self._action_expert().train(mode) + self._set_inference_cuda_graph_enabled(not mode) + return self + + def _freeze_input_embeddings(self) -> None: + embedding_modules: list[torch.nn.Module] = [] + seen_module_ids: set[int] = set() + hf_model = self._hf_model() + for module in (hf_model, self._backbone()): + get_input_embeddings = getattr(module, "get_input_embeddings", None) + if not callable(get_input_embeddings): + continue + embeddings = get_input_embeddings() + if embeddings is None or id(embeddings) in seen_module_ids: + continue + embedding_modules.append(embeddings) + seen_module_ids.add(id(embeddings)) + + if not embedding_modules: + raise RuntimeError("freeze_embedding=true, but MolmoAct2 checkpoint exposes no input embeddings.") + + lm_head = getattr(hf_model, "lm_head", None) + lm_head_params = {id(param) for param in lm_head.parameters()} if lm_head is not None else set() + embedding_params = [param for embeddings in embedding_modules for param in embeddings.parameters()] + if any(id(param) in lm_head_params for param in embedding_params): + raise RuntimeError( + "freeze_embedding=true would also freeze lm_head because input embeddings and lm_head " + "share parameters in this checkpoint." + ) + for param in embedding_params: + param.requires_grad = False + + def get_optim_params(self) -> list[dict[str, Any]]: + vit_params: list[Tensor] = [] + connector_params: list[Tensor] = [] + action_expert_params: list[Tensor] = [] + vlm_params: list[Tensor] = [] + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + if "action_expert" in name: + action_expert_params.append(param) + elif any(part in name for part in ("image_pooling_2d", "image_projector")): + connector_params.append(param) + elif any(part in name for part in ("vision", "image_encoder", "vit")): + vit_params.append(param) + elif any(part in name for part in ("multi_modal_projector", "connector", "mm_projector")): + connector_params.append(param) + else: + vlm_params.append(param) + + vlm_lr = 5e-5 if self.config.enable_lora_vlm else self.config.optimizer_lr + vit_lr = 5e-5 if self.config.enable_lora_vlm else self.config.optimizer_vit_lr + connector_lr = 5e-5 if self.config.enable_lora_vlm else self.config.optimizer_connector_lr + + groups: list[dict[str, Any]] = [] + if vlm_params: + groups.append({"params": vlm_params, "lr": vlm_lr}) + if vit_params: + groups.append({"params": vit_params, "lr": vit_lr}) + if connector_params: + groups.append({"params": connector_params, "lr": connector_lr}) + if action_expert_params: + groups.append({"params": action_expert_params, "lr": self.config.optimizer_action_expert_lr}) + return groups + + def _model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + compute_dtype = _torch_dtype(self.config.model_dtype) + return { + key: value.to(dtype=compute_dtype) if value.is_floating_point() else value + for key, value in batch.items() + if key in _MODEL_INPUT_KEYS and value is not None + } + + def _output_action_dim(self, batch: dict[str, Tensor]) -> int: + action_feature = self.config.output_features.get(ACTION) + if action_feature is not None and action_feature.shape: + action_dim = int(action_feature.shape[0]) + if action_dim > 0: + return action_dim + + action_dim_is_pad = batch.get("action_dim_is_pad") + if action_dim_is_pad is not None: + valid_counts = (~action_dim_is_pad.to(dtype=torch.bool)).sum(dim=-1) + if bool((valid_counts == valid_counts[0]).all()) and int(valid_counts[0]) > 0: + return int(valid_counts[0]) + + raise RuntimeError("MolmoAct2 inference requires a positive action dimension in output_features.") + + def _hf_model(self): + base_model = getattr(self.model, "base_model", None) + wrapped_model = getattr(base_model, "model", None) if base_model is not None else None + return wrapped_model if wrapped_model is not None else self.model + + def _backbone(self): + return self._hf_model().model + + def _override_loaded_max_action_horizon(self, action_horizon: int) -> None: + if action_horizon < 1: + raise ValueError(f"action_horizon must be >= 1, got {action_horizon}.") + hf_model = self._hf_model() + for cfg in (getattr(hf_model, "config", None), getattr(self._backbone(), "config", None)): + if cfg is not None: + cfg.max_action_horizon = int(action_horizon) + + def _generation_action_horizon(self) -> int: + chunk_size = getattr(self.config, "chunk_size", None) + if chunk_size is not None: + return int(chunk_size) + hf_model = self._hf_model() + for cfg in (getattr(hf_model, "config", None), getattr(self._backbone(), "config", None)): + if cfg is None: + continue + value = getattr(cfg, "max_action_horizon", None) + if value is not None: + return int(value) + raise RuntimeError("MolmoAct2 could not resolve an action generation horizon.") + + @staticmethod + def _mask_discrete_action_spans( + *, + input_ids: Tensor, + mask: Tensor, + start_token_id: int | None, + end_token_id: int | None, + ) -> Tensor: + if start_token_id is None or end_token_id is None: + return mask + mask = mask.clone() + for batch_idx in range(input_ids.shape[0]): + row = input_ids[batch_idx] + starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist() + ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist() + end_ptr = 0 + for start in starts: + while end_ptr < len(ends) and ends[end_ptr] < start: + end_ptr += 1 + if end_ptr >= len(ends): + mask[batch_idx, start:] = False + break + end = int(ends[end_ptr]) + mask[batch_idx, start : end + 1] = False + end_ptr += 1 + return mask + + def _encoder_attention_mask_for_action_expert( + self, + *, + input_ids: Tensor | None, + attention_mask: Tensor | None, + ) -> Tensor | None: + backbone = self._backbone() + get_encoder_attention_mask = getattr(backbone, "_get_encoder_attention_mask", None) + if callable(get_encoder_attention_mask): + mask = get_encoder_attention_mask(input_ids, attention_mask) + elif attention_mask is not None: + mask = attention_mask.to(dtype=torch.bool) + elif input_ids is not None: + mask = input_ids != -1 + else: + return None + + if getattr(self.config, "action_mode", None) != "both" or input_ids is None or mask is None: + return mask + + mask = mask.to(dtype=torch.bool).clone() + eos_token_id = getattr(self.model.config, "eos_token_id", None) + if eos_token_id is not None: + mask &= input_ids != int(eos_token_id) + return self._mask_discrete_action_spans( + input_ids=input_ids, + mask=mask, + start_token_id=getattr(self.model.config, "action_start_token_id", None), + end_token_id=getattr(self.model.config, "action_end_token_id", None), + ) + + @staticmethod + def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]: + attention_mask = model_inputs.get("attention_mask") + if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): + model_inputs = dict(model_inputs) + model_inputs.pop("attention_mask", None) + return model_inputs + + def _load_discrete_action_tokenizer(self) -> Any: + if self.action_tokenizer is None: + require_package("transformers", extra="molmoact2") + from transformers import AutoProcessor + + self.action_tokenizer = AutoProcessor.from_pretrained( + self.config.discrete_action_tokenizer, + trust_remote_code=self.config.trust_remote_code, + token=_hf_token(), + ) + return self.action_tokenizer + + def _resolve_inference_action_mode(self, requested_mode: str | None) -> str: + training_mode = self._training_action_mode() + if requested_mode is None: + requested_mode = self.config.inference_action_mode + if requested_mode is None: + raise ValueError( + "MolmoAct2 inference requires `inference_action_mode` to be set explicitly " + "to either 'continuous' or 'discrete'." + ) + if requested_mode not in {"continuous", "discrete"}: + raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.") + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.") + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.") + return requested_mode + + @staticmethod + def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int: + seed = 0 + for idx in range(batch_size): + seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1) + return seed + + @staticmethod + def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None: + task = batch.get("task") + if task is None: + task = batch.get("observation.language") + if task is None: + return None + if isinstance(task, str): + return (task,) + if isinstance(task, (list, tuple)): + return tuple(str(item) for item in task) + return (str(task),) + + def _rollout_generator_for_inputs( + self, + batch: dict[str, Any], + *, + batch_size: int, + device: torch.device, + ) -> torch.Generator | None: + if not bool(getattr(self.config, "per_episode_seed", False)): + return None + if self._rollout_action_generator is not None: + return self._rollout_action_generator + + task_signature = self._rollout_task_signature(batch) + if task_signature != self._rollout_task_key: + self._rollout_task_key = task_signature + self._rollout_index_for_task = 0 + else: + self._rollout_index_for_task += 1 + + base_seed = int(getattr(self.config, "eval_seed", None) or 0) + first_seed = base_seed + self._rollout_index_for_task * batch_size + generator_device = ( + device if device.type == "cuda" and torch.cuda.is_available() else torch.device("cpu") + ) + generator = torch.Generator(device=generator_device) + generator.manual_seed(self._combine_rollout_seeds(first_seed, batch_size)) + self._rollout_action_generator = generator + return generator + + @staticmethod + def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None: + if mask is None: + return None + return ( + mask.unsqueeze(1) + .expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1))) + .reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:]) + ) + + @staticmethod + def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None: + if action_dim_is_pad is None: + return None + mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool) + if mask.ndim == 1: + mask = mask.unsqueeze(0) + if mask.shape[-1] != target.shape[-1]: + raise ValueError( + f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}." + ) + if mask.shape[0] == 1 and target.shape[0] != 1: + mask = mask.expand(target.shape[0], -1) + if mask.shape[0] != target.shape[0]: + raise ValueError( + f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}." + ) + while mask.ndim < target.ndim: + mask = mask.unsqueeze(1) + return mask + + @classmethod + def _mask_action_dim_tensor(cls, tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: + if not cls._mask_enabled_static(action_dim_is_pad): + return tensor + valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad) + if valid_mask is None: + return tensor + return tensor.masked_fill(~valid_mask, 0) + + @staticmethod + def _mask_enabled_static(action_dim_is_pad: Tensor | None) -> bool: + return action_dim_is_pad is not None + + @classmethod + def _apply_action_dim_padding_mask(cls, loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: + valid_mask = cls._action_dim_valid_mask(loss, action_dim_is_pad) + if valid_mask is None: + return loss + valid = valid_mask.to(dtype=loss.dtype) + denom = valid.sum(dim=-1).clamp_min(1.0) + return (loss * valid).sum(dim=-1) / denom + + @staticmethod + def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor: + if action_horizon_is_pad is None: + return loss + valid_action = ( + (~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1) + ) + return loss * valid_action + + def _prepare_flow_matching_tensors( + self, + *, + actions: Tensor, + action_dim_is_pad: Tensor | None, + timesteps: Tensor | None = None, + noise: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + action_expert = self._backbone()._require_action_expert() + action_dtype = next(action_expert.parameters()).dtype + actions = actions.to(dtype=action_dtype) + batch_size = int(actions.shape[0]) + device = actions.device + num_flow_timesteps = max(1, int(self.config.num_flow_timesteps)) + + if timesteps is None: + timesteps = ( + _sample_beta_timesteps( + batch_size=batch_size * num_flow_timesteps, + device=device, + cutoff=self.config.flow_matching_cutoff, + time_offset=self.config.flow_matching_time_offset, + time_scale=self.config.flow_matching_time_scale, + alpha=self.config.flow_matching_beta_alpha, + beta=self.config.flow_matching_beta_beta, + ) + .to(dtype=action_dtype) + .view(batch_size, num_flow_timesteps) + ) + else: + expected_timesteps_shape = (batch_size, num_flow_timesteps) + timesteps = timesteps.to(device=device, dtype=action_dtype) + if tuple(timesteps.shape) != expected_timesteps_shape: + raise ValueError( + f"flow timesteps must have shape {expected_timesteps_shape}, got {tuple(timesteps.shape)}." + ) + + if self.config.mask_action_dim_padding: + actions = self._mask_action_dim_tensor(actions, action_dim_is_pad) + + expected_noise_shape = (batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2]) + if noise is None: + noise = torch.randn(*expected_noise_shape, device=device, dtype=actions.dtype) + else: + noise = noise.to(device=device, dtype=actions.dtype) + if tuple(noise.shape) != expected_noise_shape: + raise ValueError( + f"flow noise must have shape {expected_noise_shape}, got {tuple(noise.shape)}." + ) + if self.config.mask_action_dim_padding: + noise = self._mask_action_dim_tensor(noise, action_dim_is_pad) + + t_broadcast = timesteps.view(batch_size, num_flow_timesteps, 1, 1) + actions_expanded = actions.unsqueeze(1).expand(-1, num_flow_timesteps, -1, -1) + xt = (1.0 - t_broadcast) * noise + t_broadcast * actions_expanded + target_velocity = actions_expanded - noise + return actions, timesteps, xt, target_velocity + + def _prepare_joint_training_backbone_inputs( + self, + model_inputs: dict[str, Tensor], + ) -> tuple[Tensor, Tensor | dict[str, Any], Tensor, Tensor]: + backbone = self._backbone() + input_ids = model_inputs.get("input_ids") + inputs_embeds = model_inputs.get("inputs_embeds") + if (input_ids is None) == (inputs_embeds is None): + raise ValueError( + "MolmoAct2 joint flow training requires exactly one of input_ids or inputs_embeds." + ) + + images = None + token_pooling = None + merge_visual_inputs = getattr(backbone, "merge_visual_inputs", None) + if callable(merge_visual_inputs): + images, token_pooling = merge_visual_inputs( + input_ids=input_ids, + pixel_values=model_inputs.get("pixel_values"), + image_token_pooling=model_inputs.get("image_token_pooling"), + image_grids=model_inputs.get("image_grids"), + image_num_crops=model_inputs.get("image_num_crops"), + pixel_values_videos=model_inputs.get("pixel_values_videos"), + video_token_pooling=model_inputs.get("video_token_pooling"), + video_grids=model_inputs.get("video_grids"), + ) + elif ( + model_inputs.get("pixel_values") is not None + or model_inputs.get("pixel_values_videos") is not None + ): + raise RuntimeError("MolmoAct2 checkpoint does not expose merge_visual_inputs for joint training.") + + if images is not None and inputs_embeds is not None: + raise ValueError("MolmoAct2 joint flow training cannot combine inputs_embeds with visual inputs.") + if inputs_embeds is None: + inputs_embeds, _image_features = backbone.build_input_embeddings(input_ids, images, token_pooling) + + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + position_ids = model_inputs.get("position_ids") + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + attention_mask = model_inputs.get("attention_mask") + if isinstance(attention_mask, dict): + causal_mask_mapping = attention_mask + else: + causal_mask_mapping = backbone._build_native_attention_bias( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + token_type_ids=model_inputs.get("token_type_ids"), + past_key_values=None, + ) + return inputs_embeds, causal_mask_mapping, position_ids, cache_position + + @staticmethod + def _decoder_layer_kv_outputs( + layer_outputs: tuple[Any, ...], *, output_attentions: bool + ) -> tuple[Tensor, Tensor]: + output_idx = 2 if output_attentions else 1 + return layer_outputs[output_idx], layer_outputs[output_idx + 1] + + @staticmethod + def _action_time_conditioning(action_expert: torch.nn.Module, timesteps: Tensor) -> Tensor: + time_conditioning = getattr(action_expert, "_time_conditioning", None) + if callable(time_conditioning): + return time_conditioning(timesteps) + return action_expert.time_embed(timesteps) + + def _compute_flow_matching_loss_joint_per_layer( + self, + *, + batch: dict[str, Tensor], + model_inputs: dict[str, Tensor], + timesteps: Tensor | None = None, + noise: Tensor | None = None, + reduction: str = "mean", + ) -> tuple[Tensor, Tensor]: + if reduction not in {"mean", "none"}: + raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.") + backbone = self._backbone() + transformer = getattr(backbone, "transformer", None) + action_expert = backbone._require_action_expert() + if transformer is None: + raise RuntimeError("MolmoAct2 joint flow training requires a patchable text transformer.") + if len(action_expert.blocks) != int(transformer.config.num_hidden_layers): + raise RuntimeError( + "MolmoAct2 joint flow training requires one action expert block per text transformer layer." + ) + + actions, timesteps, xt, target_velocity = self._prepare_flow_matching_tensors( + actions=batch[ACTION], + action_dim_is_pad=batch.get("action_dim_is_pad"), + timesteps=timesteps, + noise=noise, + ) + num_flow_timesteps = max(1, int(self.config.num_flow_timesteps)) + batch_size = int(actions.shape[0]) + device = actions.device + xt_flat = xt.reshape(batch_size * num_flow_timesteps, actions.shape[1], actions.shape[2]) + timesteps_flat = timesteps.reshape(batch_size * num_flow_timesteps) + + hidden_states, causal_mask_mapping, position_ids, cache_position = ( + self._prepare_joint_training_backbone_inputs(model_inputs) + ) + if hidden_states.shape[0] != batch_size: + raise ValueError( + f"Backbone batch size {hidden_states.shape[0]} does not match action batch size {batch_size}." + ) + + encoder_attention_mask = self._encoder_attention_mask_for_action_expert( + input_ids=model_inputs.get("input_ids"), + attention_mask=model_inputs.get("attention_mask"), + ) + action_attention_mask = None + if batch.get("action_horizon_is_pad") is not None: + action_attention_mask = ~batch["action_horizon_is_pad"].to(device=device, dtype=torch.bool) + + valid_action = None + if action_attention_mask is not None: + valid_action = action_attention_mask.to(device=device, dtype=actions.dtype).unsqueeze(-1) + valid_action = self._expand_mask(valid_action, num_flow_timesteps) + + rope_cache = None + if len(action_expert.blocks) > 0 and action_expert.blocks[0].self_attn.rope is not None: + rope_cache = action_expert.blocks[0].self_attn.rope.build_cache( + seq_len=actions.shape[1], + device=device, + dtype=actions.dtype, + ) + + cross_mask = action_expert._build_cross_attention_mask( + encoder_attention_mask, + batch_size, + actions.dtype, + ) + cross_mask = self._expand_mask(cross_mask, num_flow_timesteps) + self_mask = action_expert._build_self_attention_mask( + action_attention_mask, + actions.shape[1], + device, + actions.dtype, + ) + self_mask = self._expand_mask(self_mask, num_flow_timesteps) + + conditioning = self._action_time_conditioning(action_expert, timesteps_flat) + action_hidden = action_expert.action_embed(xt_flat) + if valid_action is not None: + action_hidden = action_hidden * valid_action + + if transformer.config.rope_scaling_layers is not None: + position_embeddings_mapping = { + "default": transformer.rotary_embs["default"](hidden_states, position_ids), + "scaling": transformer.rotary_embs["scaling"](hidden_states, position_ids), + } + else: + position_embeddings = transformer.rotary_emb(hidden_states, position_ids) + + use_gradient_checkpointing = bool( + getattr(self.config, "gradient_checkpointing", False) + and self.training + and torch.is_grad_enabled() + ) + + def run_layer( + layer_idx: int, layer_hidden: Tensor, layer_action_hidden: Tensor + ) -> tuple[Tensor, Tensor]: + decoder_block = transformer.blocks[layer_idx] + action_block = action_expert.blocks[layer_idx] + if transformer.config.rope_scaling_layers is not None: + position_embeddings_i = ( + position_embeddings_mapping["scaling"] + if layer_idx in transformer.config.rope_scaling_layers + else position_embeddings_mapping["default"] + ) + else: + position_embeddings_i = position_embeddings + + layer_outputs = decoder_block( + layer_hidden, + position_embeddings=position_embeddings_i, + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=None, + output_attentions=False, + use_cache=False, + cache_position=cache_position, + collect_layer_kv_states=True, + ) + next_hidden = layer_outputs[0] + key_states, value_states = self._decoder_layer_kv_outputs(layer_outputs, output_attentions=False) + key_states = backbone._cache_to_sequence(key_states) + value_states = backbone._cache_to_sequence(value_states) + if self.config.enable_knowledge_insulation: + key_states = key_states.detach() + value_states = value_states.detach() + + k_ctx = action_expert._project_kv_tensor(key_states, action_expert.context_k_proj) + v_ctx = action_expert._project_kv_tensor(value_states, action_expert.context_v_proj) + k_norm = action_block.cross_attn.k_norm + if k_norm is not None: + k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2) + if num_flow_timesteps != 1: + k_ctx = self._expand_mask(k_ctx, num_flow_timesteps) + v_ctx = self._expand_mask(v_ctx, num_flow_timesteps) + + next_action_hidden = action_block( + layer_action_hidden, + conditioning, + cross_kv=(k_ctx, v_ctx), + self_attn_mask=self_mask, + attn_mask=cross_mask, + is_causal=action_expert.config.causal_attn, + modulation=None, + rope_cache=rope_cache, + ) + if valid_action is not None: + next_action_hidden = next_action_hidden * valid_action + return next_hidden, next_action_hidden + + for layer_idx in range(int(transformer.config.num_hidden_layers)): + if use_gradient_checkpointing: + hidden_states, action_hidden = torch.utils.checkpoint.checkpoint( + lambda layer_hidden, layer_action_hidden, idx=layer_idx: run_layer( + idx, + layer_hidden, + layer_action_hidden, + ), + hidden_states, + action_hidden, + use_reentrant=False, + ) + else: + hidden_states, action_hidden = run_layer(layer_idx, hidden_states, action_hidden) + + hidden_states = transformer.ln_f(hidden_states) + pred_velocity = action_expert.final_layer(action_hidden, conditioning) + if valid_action is not None: + pred_velocity = pred_velocity * valid_action + pred_velocity = pred_velocity.reshape( + batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2] + ) + + loss = F.mse_loss(pred_velocity, target_velocity, reduction="none") + loss = self._apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad")) + if self.config.mask_action_dim_padding: + loss = self._apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad")) + loss = loss.reshape(batch_size, -1).mean(dim=1) + if reduction == "mean": + loss = loss.mean() + return loss, hidden_states + + def _discrete_token_weights(self, valid_positions: Tensor) -> Tensor | None: + mode = self.config.discrete_loss_token_weighting + if mode in {"none", "token", "root_subsegments"}: + return None + if mode != "root_subsegments_root_tokens" and mode != "root_tokens": + raise ValueError(f"Unsupported discrete_loss_token_weighting={mode!r}.") + + token_counts = valid_positions.sum(dim=1).to(dtype=torch.float32) + example_weights = torch.zeros_like(token_counts) + nonempty = token_counts > 0 + example_weights[nonempty] = 2.0 / torch.sqrt(token_counts[nonempty]) + return example_weights[:, None].expand_as(valid_positions)[valid_positions].to(dtype=torch.float32) + + @staticmethod + def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor: + if weights is None: + return values.mean() + weights = weights.to(device=values.device, dtype=values.dtype) + return torch.dot(values, weights) / weights.sum().clamp_min(1.0) + + @staticmethod + def _weighted_per_example( + values: Tensor, + weights: Tensor | None, + example_indices: Tensor, + batch_size: int, + ) -> Tensor: + values = values.float() + if weights is None: + weights = torch.ones_like(values) + else: + weights = weights.to(device=values.device, dtype=values.dtype) + loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) + weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) + loss_sum.scatter_add_(0, example_indices, values * weights) + weight_sum.scatter_add_(0, example_indices, weights) + global_weight_sum = weight_sum.sum().clamp_min(1.0) + return loss_sum * float(batch_size) / global_weight_sum + + def _discrete_loss_from_backbone_outputs( + self, + batch: dict[str, Tensor], + outputs: Any, + reduction: str = "mean", + ) -> tuple[Tensor, Tensor | None]: + if reduction not in {"mean", "none"}: + raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.") + labels = batch.get("labels") + if labels is None: + raise RuntimeError("MolmoAct2 discrete training requires labels.") + hidden_states = outputs.last_hidden_state + if hidden_states is None: + raise RuntimeError("MolmoAct2 backbone did not return last_hidden_state.") + + ignore_index = -100 + shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].contiguous() + valid_positions = shift_labels != ignore_index + if not bool(valid_positions.any()): + raise RuntimeError("MolmoAct2 discrete training labels contain no valid action tokens.") + + hidden_size = hidden_states.shape[-1] + selected_hidden = hidden_states.reshape(-1, hidden_size)[valid_positions.reshape(-1)] + selected_labels = shift_labels.reshape(-1)[valid_positions.reshape(-1)].to( + device=hidden_states.device + ) + logits = F.linear(selected_hidden, self.model.lm_head.weight).float() + log_z = logits.logsumexp(dim=-1) + target_logits = logits.gather(dim=-1, index=selected_labels[:, None]).squeeze(-1) + token_ce_loss = log_z - target_logits + token_weights = self._discrete_token_weights(valid_positions) + if reduction == "none": + example_indices = valid_positions.nonzero(as_tuple=False)[:, 0].to(device=hidden_states.device) + ce_loss = self._weighted_per_example( + token_ce_loss, + token_weights, + example_indices, + int(labels.shape[0]), + ) + else: + ce_loss = self._weighted_mean(token_ce_loss, token_weights) + if not self.config.softmax_auxiliary_loss: + return ce_loss, None + + if reduction == "none": + z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_per_example( + log_z.pow(2), + token_weights, + example_indices, + int(labels.shape[0]), + ) + else: + z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_mean( + log_z.pow(2), token_weights + ) + return ce_loss, z_loss + + @staticmethod + def _extract_discrete_token_bins( + generated_ids: list[int], + start_token_id: int, + end_token_id: int, + token_id_to_bin: dict[int, int], + ) -> list[int]: + start_idx = None + end_idx = None + for idx, token_id in enumerate(generated_ids): + if token_id == start_token_id: + start_idx = idx + break + if start_idx is not None: + for idx in range(start_idx + 1, len(generated_ids)): + if generated_ids[idx] == end_token_id: + end_idx = idx + break + span_start = 0 if start_idx is None else start_idx + 1 + span_end = len(generated_ids) if end_idx is None else end_idx + return [ + int(token_id_to_bin[token_id]) + for token_id in generated_ids[span_start:span_end] + if token_id in token_id_to_bin + ] + + def _action_token_id_to_bin(self) -> dict[int, int]: + method = getattr(self.model, "_action_token_id_to_bin", None) + if callable(method): + return dict(method()) + start = getattr(self.model.config, "action_token_start_id", None) + num_tokens = int(getattr(self.model.config, "num_action_tokens", 0) or 0) + if start is None or num_tokens <= 0: + return {} + return {int(start) + idx: idx for idx in range(num_tokens)} + + def _require_discrete_eos_token_id(self) -> int: + method = getattr(self.model, "_require_eos_token_id", None) + if callable(method): + return int(method()) + eos_token_id = getattr(self.model.config, "eos_token_id", None) + if eos_token_id is None and getattr(self.model, "generation_config", None) is not None: + eos_token_id = getattr(self.model.generation_config, "eos_token_id", None) + if isinstance(eos_token_id, (list, tuple)): + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None: + raise RuntimeError("Discrete action generation requires eos_token_id in the checkpoint config.") + return int(eos_token_id) + + def _discrete_generation_max_steps(self) -> int: + if self.config.discrete_generation_max_steps is not None: + return int(self.config.discrete_generation_max_steps) + return max(1, self._generation_action_horizon() * 16) + + def _continue_discrete_generation_from_output( + self, + initial_output: Any, + *, + past_key_values: Any | None, + attention_mask: Tensor | None, + end_token_id: int, + max_steps: int, + attention_bias: Tensor | None = None, + ) -> Tensor: + consume_generation_tokens = getattr(self.model, "_consume_generation_tokens", None) + ar_decode_step = getattr(self.model, "_run_ar_decode_step", None) + if ar_decode_step is None: + ar_decode_step = getattr(self.model, "_run_depth_decode_step", None) + if attention_bias is None and not callable(consume_generation_tokens): + raise RuntimeError("MolmoAct2 checkpoint does not expose discrete token generation helpers.") + if attention_bias is not None and not callable(ar_decode_step): + raise RuntimeError("MolmoAct2 checkpoint does not expose graph-backed AR decode helpers.") + + generated_tokens: list[Tensor] = [] + current_output = initial_output + current_past_key_values = past_key_values + current_attention_mask = attention_mask + hit_end = False + for _ in range(int(max_steps)): + next_token = torch.argmax(current_output.logits[:, -1, :], dim=-1) + generated_tokens.append(next_token) + if bool((next_token == int(end_token_id)).all()): + hit_end = True + break + if attention_bias is None: + current_output, current_attention_mask = consume_generation_tokens( + next_token, + past_key_values=current_past_key_values, + attention_mask=current_attention_mask, + ) + current_past_key_values = current_output.past_key_values + else: + last_hidden, current_past_key_values = ar_decode_step( + next_token, + past_key_values=current_past_key_values, + attention_bias=attention_bias, + ) + current_output = types.SimpleNamespace( + logits=self.model.lm_head(last_hidden), + past_key_values=current_past_key_values, + ) + if not generated_tokens: + raise RuntimeError("Discrete continuation generated no tokens.") + if not hit_end: + raise RuntimeError( + f"Discrete continuation did not emit end token {int(end_token_id)} within {int(max_steps)} steps." + ) + return torch.stack(generated_tokens, dim=1) + + def _make_discrete_ar_graph_decode_inputs( + self, + model_inputs: dict[str, Tensor], + *, + max_steps: int, + ) -> tuple[Any | None, Tensor | None]: + if not bool(getattr(self.config, "enable_inference_cuda_graph", False)): + return None, None + if self.training or self.model.training: + return None, None + ar_decode_step = getattr(self.model, "_run_ar_decode_step", None) + if ar_decode_step is None: + ar_decode_step = getattr(self.model, "_run_depth_decode_step", None) + make_attention_bias = getattr(self.model, "_make_depth_decode_attention_bias", None) + if not callable(ar_decode_step) or not callable(make_attention_bias): + return None, None + + make_static_cache = getattr(self.model, "_make_ar_decode_static_cache", None) + if callable(make_static_cache): + static_cache = make_static_cache(model_inputs, max_steps=max_steps) + else: + graph_manager = getattr(self.model, "depth_decode_cuda_graph_manager", None) + make_manager_static_cache = getattr(graph_manager, "make_static_cache", None) + if not callable(make_manager_static_cache): + return None, None + prompt_len = int(model_inputs["input_ids"].shape[1]) + static_cache = make_manager_static_cache(max_cache_len=prompt_len + max(1, int(max_steps))) + + attention_bias = make_attention_bias(model_inputs, static_cache) + return static_cache, attention_bias + + def _decode_discrete_action_chunk(self, generated_token_ids: Tensor, *, action_dim: int) -> Tensor: + if ( + getattr(self.model.config, "action_start_token_id", None) is None + or getattr(self.model.config, "action_end_token_id", None) is None + ): + raise RuntimeError("Discrete action generation requires / token IDs.") + token_id_to_bin = self._action_token_id_to_bin() + if not token_id_to_bin: + raise RuntimeError( + "Discrete action generation requires indexed action tokens in the checkpoint config." + ) + + action_tokenizer = self._load_discrete_action_tokenizer() + if generated_token_ids.ndim == 1: + generated_token_ids = generated_token_ids.unsqueeze(0) + if generated_token_ids.ndim == 3: + generated_token_ids = generated_token_ids[:, 0, :] + if generated_token_ids.ndim != 2: + raise ValueError(f"Unexpected generated token tensor shape {tuple(generated_token_ids.shape)}.") + + chunks: list[Tensor] = [] + for token_row in generated_token_ids: + generated_ids = [int(token_id) for token_id in token_row.detach().cpu().tolist()] + discrete_token_ids = self._extract_discrete_token_bins( + generated_ids, + int(self.model.config.action_start_token_id), + int(self.model.config.action_end_token_id), + token_id_to_bin, + ) + if not discrete_token_ids: + raise RuntimeError( + "Model generated no decodable action tokens between /." + ) + try: + decoded = action_tokenizer.decode( + [discrete_token_ids], + time_horizon=self._generation_action_horizon(), + action_dim=int(action_dim), + ) + except TypeError: + decoded = action_tokenizer.decode([discrete_token_ids]) + action_chunk = np.asarray(decoded, dtype=np.float32) + if action_chunk.ndim == 1: + action_chunk = action_chunk[None, :] + elif action_chunk.ndim == 3: + if int(action_chunk.shape[0]) != 1: + action_chunk = action_chunk.reshape(action_chunk.shape[-2], action_chunk.shape[-1]) + else: + action_chunk = action_chunk[0] + elif action_chunk.ndim > 3: + action_chunk = action_chunk.reshape(action_chunk.shape[-2], action_chunk.shape[-1]) + if action_chunk.ndim != 2: + raise RuntimeError(f"Decoded action chunk has unexpected shape {action_chunk.shape}.") + chunks.append(torch.as_tensor(action_chunk, device=token_row.device, dtype=torch.float32)) + return torch.stack(chunks, dim=0) + + def _generate_discrete_actions_from_inputs( + self, + *, + model_inputs: dict[str, Tensor], + action_dim: int, + ) -> Tensor: + model_inputs = self._drop_trivial_attention_mask(model_inputs) + max_steps = self._discrete_generation_max_steps() + static_cache, attention_bias = self._make_discrete_ar_graph_decode_inputs( + model_inputs, + max_steps=max_steps, + ) + prefill_kwargs: dict[str, Any] = {} + if static_cache is not None: + prefill_kwargs["past_key_values"] = static_cache + prefill_output = self.model( + **model_inputs, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + **prefill_kwargs, + ) + generated_token_ids = self._continue_discrete_generation_from_output( + prefill_output, + past_key_values=prefill_output.past_key_values, + attention_mask=model_inputs.get("attention_mask"), + end_token_id=self._require_discrete_eos_token_id(), + max_steps=max_steps, + attention_bias=attention_bias, + ) + return self._decode_discrete_action_chunk(generated_token_ids, action_dim=action_dim) + + def _generate_actions_from_inputs_with_rtc( + self, + *, + model_inputs: dict[str, Tensor], + action_dim_is_pad: Tensor | None, + num_steps: int | None, + generator: torch.Generator | None, + inference_delay: int | None, + prev_chunk_left_over: Tensor | None, + execution_horizon: int | None, + ) -> Tensor: + backbone = self._backbone() + action_expert = self._action_expert() + outputs = backbone( + **model_inputs, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + ) + encoder_kv_states = backbone._extract_kv_states(outputs.past_key_values) + encoder_attention_mask = self._encoder_attention_mask_for_action_expert( + input_ids=model_inputs.get("input_ids"), + attention_mask=model_inputs.get("attention_mask"), + ) + depth_gate, depth_mask = backbone._depth_gate_from_condition( + input_ids=model_inputs.get("input_ids"), + encoder_attention_mask=encoder_attention_mask, + layer_kv_states=encoder_kv_states, + ) + encoder_kv_states = backbone._apply_depth_gate_to_layer_kv_states( + encoder_kv_states, + depth_mask, + depth_gate, + ) + + steps = int(num_steps or backbone.config.flow_matching_num_steps) + if steps <= 0: + raise ValueError(f"num_steps must be >= 1, got {steps}.") + source_tensor = encoder_kv_states[0][0] + batch_size = int(source_tensor.shape[0]) + device = source_tensor.device + trajectory = torch.randn( + batch_size, + self._generation_action_horizon(), + int(backbone.config.max_action_dim), + device=device, + dtype=torch.float32, + generator=generator, + ) + if self.config.mask_action_dim_padding: + trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad) + + action_context = action_expert.prepare_context( + encoder_kv_states=encoder_kv_states, + encoder_attention_mask=encoder_attention_mask, + state_embeddings=None, + batch_size=batch_size, + seq_len=trajectory.shape[1], + device=device, + dtype=trajectory.dtype, + ) + flow_timesteps = [ + torch.full((batch_size,), idx / steps, device=device, dtype=trajectory.dtype) + for idx in range(steps) + ] + modulation_cache = action_expert.get_or_prepare_modulation_cache( + flow_timesteps, + cache_key=(steps, batch_size, device, trajectory.dtype), + ) + + dt = 1.0 / steps + mask_enabled = self.config.mask_action_dim_padding + for idx, flow_timestep in enumerate(flow_timesteps): + modulation = modulation_cache[idx] + + def denoise_step(input_trajectory: Tensor, step_modulation=modulation) -> Tensor: + velocity = action_expert.forward_with_context( + input_trajectory, + step_modulation.conditioning, + context=action_context, + modulation=step_modulation, + ) + if mask_enabled: + velocity = self._mask_action_dim_tensor(velocity, action_dim_is_pad) + return velocity + + if self._rtc_enabled(): + if self.rtc_processor is None: + raise RuntimeError("RTC is enabled but rtc_processor is not initialized.") + + def rtc_denoise_step(input_trajectory: Tensor) -> Tensor: + return -denoise_step(input_trajectory) + + rtc_time = 1.0 - float(flow_timestep[0].item()) + rtc_velocity = self.rtc_processor.denoise_step( + x_t=trajectory, + prev_chunk_left_over=prev_chunk_left_over, + inference_delay=int(inference_delay or 0), + time=rtc_time, + original_denoise_step_partial=rtc_denoise_step, + execution_horizon=execution_horizon, + ) + velocity = -rtc_velocity + else: + velocity = denoise_step(trajectory) + + trajectory = trajectory + dt * velocity + if mask_enabled: + trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad) + if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): + self.rtc_processor.track(time=float(flow_timestep[0].item()), x_t=trajectory, v_t=velocity) + + return trajectory + + def forward( + self, + batch: dict[str, Tensor], + reduction: str = "mean", + ) -> tuple[Tensor, dict[str, Any]]: + if reduction not in {"mean", "none"}: + raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.") + model_inputs = self._model_inputs(batch) + losses: list[Tensor] = [] + metrics: dict[str, Any] = {} + + if self.config.action_mode == "discrete": + outputs = self._backbone()( + **model_inputs, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + ) + discrete_ce_loss, discrete_z_loss = self._discrete_loss_from_backbone_outputs( + batch, outputs, reduction=reduction + ) + discrete_loss = ( + discrete_ce_loss if discrete_z_loss is None else discrete_ce_loss + discrete_z_loss + ) + losses.append(discrete_loss) + metrics["discrete_ce_loss"] = discrete_ce_loss.detach().float().mean().item() + if discrete_z_loss is not None: + metrics["discrete_z_loss"] = discrete_z_loss.detach().float().mean().item() + + elif self.config.action_mode == "continuous": + flow_loss, _ = self._compute_flow_matching_loss_joint_per_layer( + batch=batch, + model_inputs=model_inputs, + reduction=reduction, + ) + losses.append(flow_loss) + metrics["action_flow_loss"] = flow_loss.detach().float().mean().item() + + else: + flow_loss, hidden_states = self._compute_flow_matching_loss_joint_per_layer( + batch=batch, + model_inputs=model_inputs, + reduction=reduction, + ) + outputs = types.SimpleNamespace(last_hidden_state=hidden_states) + discrete_ce_loss, discrete_z_loss = self._discrete_loss_from_backbone_outputs( + batch, outputs, reduction=reduction + ) + discrete_loss = ( + discrete_ce_loss if discrete_z_loss is None else discrete_ce_loss + discrete_z_loss + ) + losses.append(discrete_loss) + metrics["discrete_ce_loss"] = discrete_ce_loss.detach().float().mean().item() + if discrete_z_loss is not None: + metrics["discrete_z_loss"] = discrete_z_loss.detach().float().mean().item() + losses.append(flow_loss) + metrics["action_flow_loss"] = flow_loss.detach().float().mean().item() + + loss = torch.stack(losses).sum(dim=0) + metrics["loss"] = loss.detach().float().mean().item() + return loss, metrics + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + if "action_mode" in kwargs: + raise TypeError( + "MolmoAct2 predict_action_chunk got unexpected keyword argument 'action_mode'; " + "use 'inference_action_mode'." + ) + model_inputs = self._model_inputs(batch) + inference_action_mode = self._resolve_inference_action_mode(kwargs.get("inference_action_mode")) + num_steps = kwargs.get("num_steps", getattr(self.config, "num_inference_steps", None)) + generator = kwargs.get("generator") + model_dtype = _torch_dtype(self.config.model_dtype) + device = next(self.parameters()).device + batch_size = int(next(iter(model_inputs.values())).shape[0]) + if generator is None: + generator = self._rollout_generator_for_inputs( + batch, + batch_size=batch_size, + device=device, + ) + action_dim = self._output_action_dim(batch) + autocast_context = ( + torch.autocast(device_type=device.type, dtype=model_dtype) + if device.type in {"cuda", "cpu"} and model_dtype in {torch.bfloat16, torch.float16} + else nullcontext() + ) + with autocast_context: + if inference_action_mode == "discrete": + if self._rtc_enabled(): + raise ValueError("RTC is only supported for continuous MolmoAct2 inference.") + actions = self._generate_discrete_actions_from_inputs( + model_inputs=model_inputs, + action_dim=action_dim, + ) + elif self._rtc_enabled(): + actions = self._generate_actions_from_inputs_with_rtc( + model_inputs=model_inputs, + action_dim_is_pad=batch.get("action_dim_is_pad"), + num_steps=num_steps, + generator=generator, + inference_delay=kwargs.get("inference_delay"), + prev_chunk_left_over=kwargs.get("prev_chunk_left_over"), + execution_horizon=kwargs.get("execution_horizon"), + ) + else: + actions = self._backbone().generate_actions_from_inputs( + **model_inputs, + action_dim_is_pad=batch.get("action_dim_is_pad"), + action_horizon=self._generation_action_horizon(), + num_steps=num_steps, + generator=generator, + ) + return actions[:, : self.config.n_action_steps, :action_dim].to(dtype=torch.float32) + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + if self._rtc_enabled(): + raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk") + batch_size = int(next(iter(self._model_inputs(batch).values())).shape[0]) + actions: list[Tensor] = [] + for batch_idx in range(batch_size): + queue = self._action_queues[batch_idx] + if not queue: + chunk = self.predict_action_chunk(batch, **kwargs) + for step in torch.unbind(chunk[batch_idx], dim=0): + queue.append(step) + if not queue: + raise RuntimeError("MolmoAct2 produced an empty action chunk.") + actions.append(queue.popleft()) + return torch.stack(actions, dim=0) + + def _get_default_peft_targets(self) -> dict[str, Any]: + target_modules = self._lora_target_modules(prefix=r"model\.model") + return { + "target_modules": target_modules, + "modules_to_save": [], + "r": self.config.lora_rank, + "lora_alpha": self.config.lora_alpha, + "lora_dropout": self.config.lora_dropout, + "bias": self.config.lora_bias, + } + + def _get_inner_peft_targets(self) -> dict[str, Any]: + target_modules = self._lora_target_modules(prefix="model") + return { + "target_modules": target_modules, + "modules_to_save": [], + "r": self.config.lora_rank, + "lora_alpha": self.config.lora_alpha, + "lora_dropout": self.config.lora_dropout, + "bias": self.config.lora_bias, + } + + def _lora_target_modules(self, *, prefix: str) -> str: + vlm_linear_leaves = "w1|w2|w3|wq|wk|wv|wo|att_proj|attn_out|ff_proj|ff_out|patch_embedding" + target_modules = rf"{prefix}\.(transformer|vision_backbone)\.(?:.*\.)?({vlm_linear_leaves})$" + if self.config.enable_lora_action_expert: + action_expert_linear_paths = ( + r"time_embed\.(1|3)|" + r"action_embed|context_k_proj|context_v_proj|" + r"blocks\.\d+\.self_attn\.(qkv|out_proj)|" + r"blocks\.\d+\.cross_attn\.(q_proj|out_proj)|" + r"blocks\.\d+\.mlp\.(up_proj|gate_proj|down_proj)|" + r"blocks\.\d+\.modulation\.linear|" + r"final_layer\.(modulation\.linear|linear)" + ) + target_modules = ( + f"({target_modules}|" + rf"{prefix}\.action_expert\.({action_expert_linear_paths})$)" + ) + return target_modules + + def _build_inner_lora_config(self): + require_package("peft", extra="molmoact2") + from peft import LoraConfig + + return LoraConfig(**self._get_inner_peft_targets()) + + def _apply_lora_adapters(self) -> None: + require_package("peft", extra="molmoact2") + from peft import get_peft_model + + peft_config = self._build_inner_lora_config() + self._validate_peft_config(peft_config) + + for param in self.model.parameters(): + param.requires_grad_(False) + self.model = get_peft_model(self.model, peft_config) + if not self.config.enable_lora_action_expert: + self._unfreeze_action_expert_parameters() + self.train(self.training) + + def _validate_peft_config(self, peft_config) -> None: + del peft_config + if not self.config.checkpoint_path: + raise ValueError("MolmoAct2 LoRA fine-tuning requires `policy.checkpoint_path`.") diff --git a/src/lerobot/policies/molmoact2/processor_molmoact2.py b/src/lerobot/policies/molmoact2/processor_molmoact2.py new file mode 100644 index 000000000..d4f2cf2c4 --- /dev/null +++ b/src/lerobot/policies/molmoact2/processor_molmoact2.py @@ -0,0 +1,883 @@ +from __future__ import annotations + +import json +import os +import re +from contextlib import suppress +from copy import deepcopy +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from huggingface_hub import snapshot_download +from torch import Tensor + +from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.types import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) +from lerobot.utils.import_utils import require_package + +from .configuration_molmoact2 import MolmoAct2Config, infer_molmoact2_max_sequence_length + +ACTION_OUTPUT_TOKEN = "" # nosec B105 +ACTION_START_TOKEN = "" # nosec B105 +ACTION_END_TOKEN = "" # nosec B105 +ACTION_TOKEN_PREFIX = " str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + +def _resolve_checkpoint_location( + checkpoint_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + checkpoint_path = str(checkpoint_path or "").strip() + if not checkpoint_path: + raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") + local_path = Path(checkpoint_path).expanduser() + if local_path.exists(): + return str(local_path) + return snapshot_download( + repo_id=checkpoint_path, + repo_type="model", + revision=revision, + force_download=force_download, + token=_hf_token(), + ) + + +def _load_hf_norm_stats_for_tag( + checkpoint_path: str, + *, + revision: str | None, + force_download: bool, + norm_tag: str | None, +) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]: + norm_tag = str(norm_tag or "").strip() + if not norm_tag: + raise ValueError("MolmoAct2 HF checkpoint inference requires `policy.norm_tag` for normalization.") + + checkpoint_location = Path( + _resolve_checkpoint_location( + checkpoint_path, + revision=revision, + force_download=force_download, + ) + ) + config_path = checkpoint_location / "config.json" + norm_stats_filename = "norm_stats.json" + if config_path.exists(): + with suppress(OSError, json.JSONDecodeError): + norm_stats_filename = str( + json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename + ) + + stats_path = checkpoint_location / norm_stats_filename + if not stats_path.exists(): + raise FileNotFoundError( + f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." + ) + payload = json.loads(stats_path.read_text()) + metadata_by_tag = payload.get("metadata_by_tag") + if not isinstance(metadata_by_tag, dict): + raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") + metadata = metadata_by_tag.get(norm_tag) + if metadata is None: + available = sorted(str(tag) for tag in metadata_by_tag) + raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") + if not isinstance(metadata, dict): + raise ValueError(f"MolmoAct2 norm_tag={norm_tag!r} metadata must be a mapping.") + + def numeric_stats(raw_stats: dict[str, Any]) -> dict[str, Any]: + stats: dict[str, Any] = {} + for key, value in raw_stats.items(): + if key == "names": + continue + if isinstance(value, (list, tuple)) and any(isinstance(item, str) for item in value): + continue + stats[key] = deepcopy(value) + return stats + + action_stats = metadata.get("action_stats") + state_stats = metadata.get("state_stats") + if not isinstance(action_stats, dict) or not isinstance(state_stats, dict): + raise ValueError(f"MolmoAct2 norm_tag={norm_tag!r} must define action_stats and state_stats.") + return {ACTION: numeric_stats(action_stats), OBS_STATE: numeric_stats(state_stats)}, metadata + + +def _to_numpy(value: Any) -> np.ndarray: + if isinstance(value, np.ndarray): + return value + if torch.is_tensor(value): + return value.detach().cpu().numpy() + return np.asarray(value) + + +def _normalize_image(value: Any) -> np.ndarray: + arr = _to_numpy(value) + while arr.ndim > 3 and int(arr.shape[0]) == 1: + arr = arr[0] + if arr.ndim == 2: + arr = np.stack([arr] * 3, axis=-1) + if arr.ndim == 3 and arr.shape[0] in {1, 3, 4} and arr.shape[-1] not in {1, 3, 4}: + arr = np.moveaxis(arr, 0, -1) + if arr.ndim == 3 and arr.shape[-1] == 1: + arr = np.repeat(arr, 3, axis=-1) + if arr.ndim != 3 or arr.shape[-1] not in {3, 4}: + raise ValueError(f"Unsupported image shape for MolmoAct2: {arr.shape}.") + if arr.shape[-1] == 4: + arr = arr[..., :3] + if arr.dtype in (np.float16, np.float32, np.float64): + if arr.size > 0 and float(np.nanmax(arr)) <= 1.0: + arr = arr * 255.0 + arr = np.clip(arr, 0, 255).astype(np.uint8) + elif arr.dtype != np.uint8: + arr = np.clip(arr, 0, 255).astype(np.uint8) + return arr + + +def _normalize_question_text(text: str) -> str: + normalized = re.sub(r"\s+", " ", str(text or "")).strip() + if not normalized: + return "" + previous = None + while normalized and normalized != previous: + previous = normalized + normalized = normalized.strip().strip(_QUESTION_SURROUNDING_DELIMITERS).strip() + for pattern in _QUESTION_PREFIX_PATTERNS: + normalized = pattern.sub("", normalized, count=1).strip() + normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() + normalized = normalized.rstrip(_QUESTION_TRAILING_CLOSERS).rstrip() + normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() + chunks = [chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip()] + if len(chunks) > 1: + normalized = "; ".join(chunks) + return normalized.lower() + + +def _wrap_setup_text(setup_type: str, add_setup_tokens: bool) -> str: + setup_type = str(setup_type or "") + if setup_type.startswith(SETUP_START_TOKEN) and setup_type.endswith(SETUP_END_TOKEN): + return setup_type + if not setup_type or not add_setup_tokens: + return setup_type + return f"{SETUP_START_TOKEN}{setup_type}{SETUP_END_TOKEN}" + + +def _wrap_control_text(control_mode: str, add_control_tokens: bool) -> str: + control_mode = str(control_mode or "") + if control_mode.startswith(CONTROL_START_TOKEN) and control_mode.endswith(CONTROL_END_TOKEN): + return control_mode + if not control_mode or not add_control_tokens: + return control_mode + return f"{CONTROL_START_TOKEN}{control_mode}{CONTROL_END_TOKEN}" + + +def _build_discrete_state_string(state: np.ndarray, num_state_tokens: int) -> str: + if num_state_tokens <= 0: + raise ValueError(f"num_state_tokens must be > 0, got {num_state_tokens}.") + arr = np.asarray(state, dtype=np.float32) + arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0) + arr = np.clip(arr, -1.0, 1.0) + scaled = (arr + 1.0) / 2.0 * float(num_state_tokens - 1) + token_ids = np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1).reshape(-1) + return f"{STATE_START_TOKEN}{''.join(f'{STATE_TOKEN_PREFIX}{int(token_id)}>' for token_id in token_ids)}{STATE_END_TOKEN}" + + +def _build_robot_text( + *, + task: str, + discrete_state_string: str, + setup_type: str, + control_mode: str, + add_setup_tokens: bool, + add_control_tokens: bool, + num_images: int, +) -> str: + setup_text = _wrap_setup_text(setup_type, add_setup_tokens=add_setup_tokens) + control_text = _wrap_control_text(control_mode, add_control_tokens=add_control_tokens) + state_clause = ( + f" The current state of the robot is {discrete_state_string}." if discrete_state_string else "" + ) + prompt = ( + f"The task is to {task}. The setup is {setup_text}.{state_clause} " + f"The expected control mode is {control_text}. Given these, what action should the robot take to complete the task?" + ) + if num_images <= 0: + image_prefix = "" + elif num_images == 1: + image_prefix = "<|image|>" + else: + image_prefix = "".join(f"Image {idx + 1}<|image|>" for idx in range(num_images)) + return f"{image_prefix}<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{ACTION_OUTPUT_TOKEN}" + + +def _as_text_list(value: Any, batch_size: int) -> list[str]: + if value is None: + return [""] * batch_size + if isinstance(value, str): + return [value] * batch_size + if torch.is_tensor(value): + if value.ndim == 0: + return [str(value.item())] * batch_size + flat = value.detach().cpu().reshape(-1).tolist() + texts = [str(item) for item in flat] + elif isinstance(value, np.ndarray): + if value.ndim == 0: + return [str(value.item())] * batch_size + texts = [str(item) for item in value.reshape(-1).tolist()] + elif isinstance(value, (list, tuple)): + texts = [str(item) for item in value] + else: + texts = [str(value)] + if len(texts) == batch_size: + return texts + if len(texts) == 1: + return texts * batch_size + raise ValueError(f"Expected {batch_size} task strings, got {len(texts)}.") + + +def _tokenize_discrete_action(action: np.ndarray, processor: Any) -> list[int]: + arr = np.asarray(action, dtype=np.float32) + if arr.ndim == 2: + arr = arr[None, :, :] + elif arr.ndim == 1: + arr = arr[None, None, :] + tokens_out = processor(arr) + if isinstance(tokens_out, dict): + tokens_out = tokens_out.get("input_ids", next(iter(tokens_out.values()))) + if isinstance(tokens_out, np.ndarray): + tokens_out = tokens_out.tolist() + if torch.is_tensor(tokens_out): + tokens_out = tokens_out.detach().cpu().tolist() + if not isinstance(tokens_out, list): + raise TypeError(f"Unexpected discrete action tokenizer output type: {type(tokens_out)}") + if tokens_out and isinstance(tokens_out[0], (list, tuple, np.ndarray)): + tokens_out = tokens_out[0] + return [int(token_id) for token_id in tokens_out] + + +def _build_discrete_action_string(action: np.ndarray, processor: Any) -> str: + token_ids = _tokenize_discrete_action(action, processor) + pieces = "".join(f"{ACTION_TOKEN_PREFIX}{int(token_id)}>" for token_id in token_ids) + return f"{ACTION_START_TOKEN}{pieces}{ACTION_END_TOKEN}" + + +def _single_token_id(tokenizer: Any, token: str) -> int: + token_ids = tokenizer.encode(token, add_special_tokens=False) + if len(token_ids) != 1: + raise ValueError(f"MolmoAct2 token {token!r} must encode to one token, got {token_ids}.") + return int(token_ids[0]) + + +def _flatten_feature_names(raw_names: Any) -> list[str] | None: + if raw_names is None: + return None + if isinstance(raw_names, dict): + names: list[str] = [] + for value in raw_names.values(): + if isinstance(value, (list, tuple)): + names.extend(str(item) for item in value) + elif value is not None: + names.append(str(value)) + return names or None + if isinstance(raw_names, (list, tuple)): + names = [str(item) for item in raw_names] + return names or None + return [str(raw_names)] + + +def _feature_dim(stats: dict[str, Any] | None) -> int | None: + if not isinstance(stats, dict): + return None + for key in ("mean", "std", "min", "max", "q01", "q99", "q10", "q90", "mask"): + value = stats.get(key) + if value is None: + continue + if torch.is_tensor(value): + return int(value.shape[-1]) if value.ndim > 0 else None + arr = np.asarray(value) + return int(arr.shape[-1]) if arr.ndim > 0 else None + return None + + +def _feature_names_from_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None: + if dataset_meta is None: + return None + + root = getattr(dataset_meta, "root", None) + candidate_roots = [] + if root is not None: + repo_id = str(getattr(dataset_meta, "repo_id", "") or "").strip() + if repo_id: + candidate_roots.append(Path(root) / repo_id) + candidate_roots.append(Path(root)) + for candidate_root in candidate_roots: + info_path = candidate_root / "meta" / "info.json" + if info_path.exists(): + try: + with info_path.open("r", encoding="utf-8") as f: + info = json.load(f) + names = _flatten_feature_names((info.get("features") or {}).get(feature_key, {}).get("names")) + if names: + return names + except (OSError, json.JSONDecodeError, AttributeError): + pass + + for container in ( + getattr(getattr(dataset_meta, "info", None), "features", None), + getattr(dataset_meta, "features", None), + ): + if not isinstance(container, dict): + continue + feature = container.get(feature_key) + if not isinstance(feature, dict): + continue + names = _flatten_feature_names(feature.get("names")) + if names: + return names + return None + + +def _add_gripper_masks_to_stats( + dataset_stats: dict[str, dict[str, Any]] | None, + dataset_meta: Any | None, + *, + normalize_gripper: bool, + dataset_feature_names: dict[str, Any] | None = None, +) -> dict[str, dict[str, Any]] | None: + if not dataset_stats: + return dataset_stats + + stats = deepcopy(dataset_stats) + for key in (ACTION, OBS_STATE): + feature_stats = stats.get(key) + if not isinstance(feature_stats, dict): + continue + dim = _feature_dim(feature_stats) + if dim is None: + continue + + if normalize_gripper: + feature_stats["mask"] = [True] * dim + continue + + names = _flatten_feature_names((dataset_feature_names or {}).get(key)) + if names is None: + names = _feature_names_from_meta(dataset_meta, key) + if names is None: + names = _flatten_feature_names(feature_stats.get("names")) + if names is None: + continue + if len(names) != dim: + continue + feature_stats["mask"] = ["gripper" not in name.lower() for name in names] + return stats + + +class _MolmoAct2MaskedNormalizationMixin: + def _apply_transform( + self, tensor: Tensor, key: str, feature_type: Any, *, inverse: bool = False + ) -> Tensor: + transformed = super()._apply_transform(tensor, key, feature_type, inverse=inverse) + stats = getattr(self, "_tensor_stats", {}).get(key, {}) + mask = stats.get("mask") if isinstance(stats, dict) else None + if mask is None: + return transformed + mask = mask.to(device=tensor.device, dtype=torch.bool) + if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]: + return transformed + while mask.ndim < tensor.ndim: + mask = mask.unsqueeze(0) + return torch.where(mask, transformed, tensor) + + +@ProcessorStepRegistry.register(name="molmoact2_masked_normalizer") +@dataclass +class MolmoAct2MaskedNormalizerProcessorStep(_MolmoAct2MaskedNormalizationMixin, NormalizerProcessorStep): + pass + + +@ProcessorStepRegistry.register(name="molmoact2_masked_unnormalizer") +@dataclass +class MolmoAct2MaskedUnnormalizerProcessorStep(_MolmoAct2MaskedNormalizationMixin, UnnormalizerProcessorStep): + pass + + +@ProcessorStepRegistry.register(name="molmoact2_clamp_normalized") +@dataclass +class MolmoAct2ClampNormalizedProcessorStep(ProcessorStep): + """Clamp q01/q99-normalized state and action to the range used by the old trainer.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict) and OBS_STATE in observation: + observation = observation.copy() + observation[OBS_STATE] = torch.as_tensor(observation[OBS_STATE]).clamp(-1.0, 1.0) + transition[TransitionKey.OBSERVATION] = observation + action = transition.get(TransitionKey.ACTION) + if action is not None: + transition[TransitionKey.ACTION] = torch.as_tensor(action).clamp(-1.0, 1.0) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register(name="molmoact2_pack_inputs") +@dataclass +class MolmoAct2PackInputsProcessorStep(ProcessorStep): + checkpoint_path: str + checkpoint_revision: str | None = None + checkpoint_force_download: bool = False + trust_remote_code: bool = True + action_mode: str = "both" + discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer" + image_keys: list[str] = field(default_factory=list) + setup_type: str = "" + control_mode: str = "" + normalize_language: bool = True + add_setup_tokens: bool = True + add_control_tokens: bool = True + num_state_tokens: int = 256 + max_sequence_length: int | None = None + chunk_size: int = 30 + max_action_dim: int = 32 + env_action_dim: int | None = None + + def __post_init__(self) -> None: + require_package("transformers", extra="molmoact2") + from transformers import AutoProcessor + + checkpoint_location = _resolve_checkpoint_location( + self.checkpoint_path, + revision=self.checkpoint_revision, + force_download=bool(self.checkpoint_force_download), + ) + self.processor = AutoProcessor.from_pretrained( + checkpoint_location, + trust_remote_code=self.trust_remote_code, + use_fast=False, + token=_hf_token(), + ) + self.action_processor = None + if self.action_mode in {"discrete", "both"}: + self.action_processor = AutoProcessor.from_pretrained( + self.discrete_action_tokenizer, + trust_remote_code=self.trust_remote_code, + token=_hf_token(), + ) + self._action_start_id = _single_token_id(self.processor.tokenizer, ACTION_START_TOKEN) + self._action_end_id = _single_token_id(self.processor.tokenizer, ACTION_END_TOKEN) + self._eos_token = self.processor.tokenizer.eos_token or "" + self._eos_token_id = self.processor.tokenizer.eos_token_id + + def get_config(self) -> dict[str, Any]: + return { + "checkpoint_path": self.checkpoint_path, + "checkpoint_revision": self.checkpoint_revision, + "checkpoint_force_download": self.checkpoint_force_download, + "trust_remote_code": self.trust_remote_code, + "action_mode": self.action_mode, + "discrete_action_tokenizer": self.discrete_action_tokenizer, + "image_keys": list(self.image_keys), + "setup_type": self.setup_type, + "control_mode": self.control_mode, + "normalize_language": self.normalize_language, + "add_setup_tokens": self.add_setup_tokens, + "add_control_tokens": self.add_control_tokens, + "num_state_tokens": self.num_state_tokens, + "max_sequence_length": self.max_sequence_length, + "chunk_size": self.chunk_size, + "max_action_dim": self.max_action_dim, + "env_action_dim": self.env_action_dim, + } + + def _resolve_max_sequence_length( + self, + *, + num_images: int, + state_dim: int, + action_dim: int, + action_horizon: int, + include_discrete_action: bool, + ) -> int: + if self.max_sequence_length is not None: + return int(self.max_sequence_length) + return infer_molmoact2_max_sequence_length( + num_images=num_images, + state_dim=state_dim, + action_dim=action_dim, + action_horizon=action_horizon, + include_discrete_action=include_discrete_action, + ) + + def _batch_size(self, observation: dict[str, Any], action: Tensor | None) -> int: + if action is not None: + return int(action.shape[0]) + state = observation.get(OBS_STATE) + if torch.is_tensor(state) or isinstance(state, np.ndarray): + return int(state.shape[0]) if getattr(state, "ndim", 0) > 1 else 1 + for key in self._resolve_image_keys(observation): + value = observation[key] + if torch.is_tensor(value) or isinstance(value, np.ndarray): + return int(value.shape[0]) if getattr(value, "ndim", 0) == 4 else 1 + return 1 + + def _resolve_image_keys(self, observation: dict[str, Any]) -> list[str]: + if self.image_keys: + missing = [key for key in self.image_keys if key not in observation] + if missing: + raise ValueError(f"MolmoAct2 image_keys missing from observation: {missing}.") + return list(self.image_keys) + keys = [key for key in observation if str(key).startswith(f"{OBS_IMAGES}.")] + if not keys: + keys = [key for key in observation if str(key).startswith("observation.image")] + if not keys: + raise ValueError("MolmoAct2 requires at least one image observation.") + return sorted(keys) + + def _extract_images(self, observation: dict[str, Any], batch_size: int) -> list[list[np.ndarray]]: + images_by_example: list[list[np.ndarray]] = [[] for _ in range(batch_size)] + for key in self._resolve_image_keys(observation): + value = observation[key] + for batch_idx in range(batch_size): + item = value + if (torch.is_tensor(value) or isinstance(value, np.ndarray)) and getattr( + value, "ndim", 0 + ) >= 4: + item = value[batch_idx] + images_by_example[batch_idx].append(_normalize_image(item)) + return images_by_example + + def _extract_state(self, observation: dict[str, Any], batch_size: int) -> Tensor: + if OBS_STATE not in observation: + raise ValueError("MolmoAct2 requires observation.state for discrete state prompting.") + state = torch.as_tensor(observation[OBS_STATE], dtype=torch.float32) + if state.ndim == 1: + state = state.unsqueeze(0) + if int(state.shape[0]) != batch_size: + raise ValueError(f"State batch size {state.shape[0]} does not match batch size {batch_size}.") + return state + + def _pad_action(self, action: Tensor, action_is_pad: Any | None) -> tuple[Tensor, Tensor, Tensor]: + if action.ndim == 2: + action = action.unsqueeze(1) + if action.ndim != 3: + raise ValueError(f"MolmoAct2 expected action shape [B, T, D], got {tuple(action.shape)}.") + if action.shape[-1] > self.max_action_dim: + raise ValueError( + f"Action dim {action.shape[-1]} exceeds MolmoAct2 max_action_dim={self.max_action_dim}." + ) + padded = torch.zeros( + (*action.shape[:-1], self.max_action_dim), + device=action.device, + dtype=torch.float32, + ) + padded[..., : action.shape[-1]] = action.to(dtype=torch.float32) + action_dim_is_pad = torch.ones( + (action.shape[0], self.max_action_dim), device=action.device, dtype=torch.bool + ) + action_dim_is_pad[:, : action.shape[-1]] = False + if action_is_pad is None: + action_horizon_is_pad = torch.zeros(action.shape[:2], device=action.device, dtype=torch.bool) + else: + action_horizon_is_pad = torch.as_tensor(action_is_pad, device=action.device, dtype=torch.bool) + if action_horizon_is_pad.ndim == 1: + action_horizon_is_pad = action_horizon_is_pad.unsqueeze(0) + if tuple(action_horizon_is_pad.shape) != tuple(action.shape[:2]): + raise ValueError( + "action_is_pad must match action horizon shape: " + f"got {tuple(action_horizon_is_pad.shape)} for action {tuple(action.shape)}." + ) + return padded, action_horizon_is_pad, action_dim_is_pad + + def _build_labels(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: + labels = torch.full_like(input_ids, -100) + for batch_idx in range(input_ids.shape[0]): + valid = attention_mask[batch_idx].to(dtype=torch.bool) + row = input_ids[batch_idx] + starts = (row == self._action_start_id).nonzero(as_tuple=False).flatten().tolist() + ends = (row == self._action_end_id).nonzero(as_tuple=False).flatten().tolist() + end_ptr = 0 + for start in starts: + while end_ptr < len(ends) and ends[end_ptr] < start: + end_ptr += 1 + if end_ptr >= len(ends): + raise ValueError( + "Found without matching in MolmoAct2 labels." + ) + end = int(ends[end_ptr]) + label_end = end + 1 + if ( + self._eos_token_id is not None + and label_end < int(row.shape[0]) + and int(row[label_end]) == int(self._eos_token_id) + ): + label_end += 1 + labels[batch_idx, start:label_end] = row[start:label_end] + end_ptr += 1 + if not starts: + raise ValueError("No discrete action span found in MolmoAct2 training text.") + labels[batch_idx] = torch.where( + valid, labels[batch_idx], torch.full_like(labels[batch_idx], -100) + ) + return labels + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + observation = transition.get(TransitionKey.OBSERVATION) or {} + if not isinstance(observation, dict): + raise ValueError("MolmoAct2 expected an observation dictionary.") + complementary = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) + + raw_action = transition.get(TransitionKey.ACTION) + action = torch.as_tensor(raw_action, dtype=torch.float32) if raw_action is not None else None + batch_size = self._batch_size(observation, action) + state = self._extract_state(observation, batch_size) + images_by_example = self._extract_images(observation, batch_size) + + task_source = complementary.get("task") + if task_source is None: + task_source = observation.get("task") + if task_source is None: + task_source = observation.get("observation.language") + if task_source is None: + task_source = complementary.get("language_instruction") + tasks = _as_text_list(task_source, batch_size) + if self.normalize_language: + tasks = [_normalize_question_text(task) for task in tasks] + complementary["task"] = tasks + + action_padded = None + action_horizon_is_pad = None + action_dim_is_pad = torch.ones((batch_size, self.max_action_dim), dtype=torch.bool) + real_action_dim = int(self.env_action_dim or 0) + if action is not None: + action_is_pad = complementary.get("action_is_pad") + if action_is_pad is None: + action_is_pad = complementary.get("action_horizon_is_pad") + action_padded, action_horizon_is_pad, action_dim_is_pad = self._pad_action(action, action_is_pad) + real_action_dim = int(action.shape[-1]) + elif real_action_dim > 0: + action_dim_is_pad[:, :real_action_dim] = False + + prompt_texts: list[str] = [] + full_texts: list[str] = [] + flat_images: list[np.ndarray] = [] + state_np = state.detach().cpu().numpy() + build_action_labels = action is not None and self.action_mode in {"discrete", "both"} + for batch_idx in range(batch_size): + images = images_by_example[batch_idx] + flat_images.extend(images) + discrete_state = _build_discrete_state_string(state_np[batch_idx], self.num_state_tokens) + prompt = _build_robot_text( + task=tasks[batch_idx], + discrete_state_string=discrete_state, + setup_type=self.setup_type, + control_mode=self.control_mode, + add_setup_tokens=self.add_setup_tokens, + add_control_tokens=self.add_control_tokens, + num_images=len(images), + ) + prompt_texts.append(prompt) + if build_action_labels: + if self.action_processor is None: + raise ValueError("Discrete MolmoAct2 training requires an action tokenizer.") + answer = _build_discrete_action_string( + action[batch_idx].detach().cpu().numpy(), self.action_processor + ) + full_texts.append(f"{prompt}{answer}{self._eos_token}") + else: + full_texts.append(prompt) + + text = full_texts if build_action_labels else prompt_texts + inputs = self.processor(text=text, images=flat_images, return_tensors="pt", padding=True) + if action is None: + action_horizon = self.chunk_size + elif action.ndim == 2: + action_horizon = 1 + else: + action_horizon = int(action.shape[1]) + max_sequence_length = self._resolve_max_sequence_length( + num_images=max((len(images) for images in images_by_example), default=0), + state_dim=int(state.shape[-1]), + action_dim=max(real_action_dim, 1), + action_horizon=action_horizon, + include_discrete_action=build_action_labels, + ) + if int(inputs["input_ids"].shape[1]) > max_sequence_length: + raise ValueError( + f"MolmoAct2 sequence length {int(inputs['input_ids'].shape[1])} exceeds " + f"max_sequence_length={max_sequence_length}." + ) + + if build_action_labels: + inputs["labels"] = self._build_labels(inputs["input_ids"], inputs["attention_mask"]) + + complementary.update(dict(inputs)) + complementary["action_dim_is_pad"] = action_dim_is_pad + if action_horizon_is_pad is not None: + complementary["action_horizon_is_pad"] = action_horizon_is_pad + + if action_padded is not None: + transition[TransitionKey.ACTION] = action_padded + transition[TransitionKey.COMPLEMENTARY_DATA] = complementary + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register(name="molmoact2_clamp_action") +@dataclass +class MolmoAct2ClampActionProcessorStep(ProcessorStep): + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + action = transition.get(TransitionKey.ACTION) + if action is not None: + transition[TransitionKey.ACTION] = torch.as_tensor(action).clamp(-1.0, 1.0) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +def make_molmoact2_pre_post_processors( + config: MolmoAct2Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta: Any | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + env_action_dim = None + if config.output_features and ACTION in config.output_features: + env_action_dim = int(config.output_features[ACTION].shape[0]) + + hf_metadata: dict[str, Any] = {} + if dataset_stats is None and str(config.norm_tag or "").strip(): + dataset_stats, hf_metadata = _load_hf_norm_stats_for_tag( + config.checkpoint_path, + revision=config.checkpoint_revision, + force_download=bool(config.checkpoint_force_download), + norm_tag=config.norm_tag, + ) + + image_keys = list(config.image_keys) + if not image_keys and isinstance(hf_metadata.get("camera_keys"), list): + image_keys = [str(key) for key in hf_metadata["camera_keys"]] + setup_type = config.setup_type or str(hf_metadata.get("setup_type") or "") + control_mode = config.control_mode or str(hf_metadata.get("control_mode") or "") + chunk_size = int(hf_metadata.get("action_horizon") or config.chunk_size) + + masked_dataset_stats = _add_gripper_masks_to_stats( + dataset_stats, + dataset_meta, + normalize_gripper=config.normalize_gripper, + dataset_feature_names=config.dataset_feature_names, + ) + + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + MolmoAct2MaskedNormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=masked_dataset_stats, + ), + MolmoAct2ClampNormalizedProcessorStep(), + MolmoAct2PackInputsProcessorStep( + checkpoint_path=config.checkpoint_path, + checkpoint_revision=config.checkpoint_revision, + checkpoint_force_download=config.checkpoint_force_download, + trust_remote_code=config.trust_remote_code, + action_mode=config.action_mode, + discrete_action_tokenizer=config.discrete_action_tokenizer, + image_keys=image_keys, + setup_type=setup_type, + control_mode=control_mode, + normalize_language=config.normalize_language, + add_setup_tokens=config.add_setup_tokens, + add_control_tokens=config.add_control_tokens, + num_state_tokens=config.num_state_tokens, + max_sequence_length=config.max_sequence_length, + chunk_size=chunk_size, + max_action_dim=config.expected_max_action_dim, + env_action_dim=env_action_dim, + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + MolmoAct2ClampActionProcessorStep(), + MolmoAct2MaskedUnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=masked_dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/tests/policies/molmoact2/test_molmoact2.py b/tests/policies/molmoact2/test_molmoact2.py new file mode 100644 index 000000000..04c704dd0 --- /dev/null +++ b/tests/policies/molmoact2/test_molmoact2.py @@ -0,0 +1,1182 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +import torch.nn.functional as F # noqa: N812 + +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies import get_policy_class, make_policy_config +from lerobot.policies.molmoact2 import modeling_molmoact2 as molmoact2_modeling +from lerobot.policies.molmoact2.configuration_molmoact2 import ( + MolmoAct2Config, + MolmoAct2CosineDecayWithWarmupSchedulerConfig, + infer_molmoact2_max_sequence_length, +) +from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy +from lerobot.policies.molmoact2.processor_molmoact2 import ( + MolmoAct2MaskedNormalizerProcessorStep, + MolmoAct2MaskedUnnormalizerProcessorStep, + MolmoAct2PackInputsProcessorStep, + _add_gripper_masks_to_stats, + _build_discrete_state_string, + _normalize_question_text, +) +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.types import TransitionKey +from lerobot.utils.constants import ACTION, OBS_STATE + + +def test_molmoact2_policy_registration(): + cfg = make_policy_config("molmoact2", checkpoint_path="/tmp/not-a-real-checkpoint") + + assert cfg.type == "molmoact2" + assert cfg.action_mode == "both" + assert cfg.normalize_gripper is False + assert cfg.enable_knowledge_insulation is False + assert cfg.freeze_embedding is True + assert cfg.per_episode_seed is False + assert cfg.eval_seed is None + assert cfg.normalize_language is True + assert cfg.get_scheduler_preset().num_decay_steps is None + assert cfg.action_delta_indices == list(range(cfg.chunk_size)) + assert get_policy_class("molmoact2") is MolmoAct2Policy + + +def test_molmoact2_scheduler_decay_steps_auto_match_training_steps(): + param = torch.nn.Parameter(torch.ones(())) + optimizer = torch.optim.AdamW([param], lr=0.001) + config = MolmoAct2CosineDecayWithWarmupSchedulerConfig( + peak_lr=0.01, + decay_lr=0.001, + num_warmup_steps=10, + num_decay_steps=None, + ) + + scheduler = config.build(optimizer, num_training_steps=100) + for _ in range(100): + optimizer.step() + scheduler.step() + + assert scheduler.get_last_lr() == pytest.approx([0.0001]) + + +def test_molmoact2_rollout_generator_uses_eval_seed_per_task(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config(per_episode_seed=True, eval_seed=1000) + policy._rollout_action_generator = None + policy._rollout_task_key = None + policy._rollout_index_for_task = -1 + + policy.reset() + first = policy._rollout_generator_for_inputs( + {"task": ["pick", "pick", "pick"]}, + batch_size=3, + device=torch.device("cpu"), + ) + expected_first = torch.Generator().manual_seed( + MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3) + ) + assert torch.allclose(torch.rand(4, generator=first), torch.rand(4, generator=expected_first)) + + policy.reset() + second = policy._rollout_generator_for_inputs( + {"task": ["pick", "pick", "pick"]}, + batch_size=3, + device=torch.device("cpu"), + ) + expected_second = torch.Generator().manual_seed( + MolmoAct2Policy._combine_rollout_seeds(first_seed=1003, batch_size=3) + ) + assert torch.allclose(torch.rand(4, generator=second), torch.rand(4, generator=expected_second)) + + policy.reset() + new_task = policy._rollout_generator_for_inputs( + {"task": ["place", "place", "place"]}, + batch_size=3, + device=torch.device("cpu"), + ) + expected_new_task = torch.Generator().manual_seed( + MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3) + ) + assert torch.allclose(torch.rand(4, generator=new_task), torch.rand(4, generator=expected_new_task)) + + +def test_molmoact2_gripper_mask_uses_feature_names(tmp_path): + meta_dir = tmp_path / "meta" + meta_dir.mkdir() + (meta_dir / "info.json").write_text( + json.dumps( + { + "features": { + ACTION: {"names": {"motors": ["x", "gripper"]}}, + OBS_STATE: {"names": {"motors": ["joint", "gripper"]}}, + } + } + ), + encoding="utf-8", + ) + dataset_meta = SimpleNamespace(root=tmp_path) + stats = { + ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}, + OBS_STATE: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}, + } + + masked_stats = _add_gripper_masks_to_stats(stats, dataset_meta, normalize_gripper=False) + + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, False] + assert masked_stats[OBS_STATE]["mask"] == [True, False] + + features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)), + } + norm_map = { + FeatureType.ACTION: NormalizationMode.QUANTILES, + FeatureType.STATE: NormalizationMode.QUANTILES, + } + transition = { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])}, + TransitionKey.ACTION: torch.tensor([[5.0, 7.0]]), + } + normalizer = MolmoAct2MaskedNormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=masked_stats, + ) + normalized = normalizer(transition) + + assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 7.0]])) + assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, 7.0]])) + + unnormalizer = MolmoAct2MaskedUnnormalizerProcessorStep( + features={ACTION: features[ACTION]}, + norm_map=norm_map, + stats=masked_stats, + ) + unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, 7.0]])}) + + assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, 7.0]])) + + +def test_molmoact2_normalize_gripper_true_keeps_all_dims_normalized(tmp_path): + meta_dir = tmp_path / "meta" + meta_dir.mkdir() + (meta_dir / "info.json").write_text( + json.dumps({"features": {ACTION: {"names": ["x", "gripper"]}}}), + encoding="utf-8", + ) + stats = {ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}} + + masked_stats = _add_gripper_masks_to_stats( + stats, + SimpleNamespace(root=tmp_path), + normalize_gripper=True, + ) + + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, True] + + +def test_molmoact2_uses_supplied_stats_with_repo_scoped_names(tmp_path): + repo_root = tmp_path / "test-org" / "libero" + (repo_root / "meta").mkdir(parents=True) + (repo_root / "meta" / "info.json").write_text( + json.dumps({"features": {ACTION: {"names": ["x", "gripper"]}}}), + encoding="utf-8", + ) + base_stats = {ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}} + + masked_stats = _add_gripper_masks_to_stats( + base_stats, + SimpleNamespace(root=tmp_path, repo_id="test-org/libero"), + normalize_gripper=False, + ) + + assert masked_stats is not None + assert masked_stats[ACTION]["q01"] == [0.0, 0.0] + assert masked_stats[ACTION]["mask"] == [True, False] + + +def test_molmoact2_uses_config_feature_names_without_dataset_meta(): + base_stats = {ACTION: {"q01": [0.0, 0.0], "q99": [10.0, 10.0]}} + + masked_stats = _add_gripper_masks_to_stats( + base_stats, + None, + normalize_gripper=False, + dataset_feature_names={ACTION: ["x", "gripper"]}, + ) + + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, False] + + +def test_enable_lora_vlm_builds_policy_local_peft_config(): + pytest.importorskip("peft") + policy_cfg = MolmoAct2Config( + checkpoint_path="/tmp/not-a-real-checkpoint", + device="cpu", + enable_lora_vlm=True, + lora_rank=64, + push_to_hub=False, + ) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = policy_cfg + + peft_config = policy._build_inner_lora_config() + + assert peft_config.r == 64 + assert peft_config.target_modules == policy._get_inner_peft_targets()["target_modules"] + assert not policy_cfg.use_peft + + +def test_cuda_graph_managers_are_inference_only(): + class DummyManager: + def __init__(self): + self.enabled = None + + def set_enabled(self, enabled): + self.enabled = enabled + + class DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.action_cuda_graph_manager = DummyManager() + + def _require_action_expert(self): + return torch.nn.Linear(1, 1) + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = DummyBackbone() + self.depth_decode_cuda_graph_manager = DummyManager() + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(train_action_expert_only=False, enable_inference_cuda_graph=True) + policy.model = DummyModel() + + policy.train() + assert policy.model.model.action_cuda_graph_manager.enabled is False + assert policy.model.depth_decode_cuda_graph_manager.enabled is False + + policy.eval() + assert policy.model.model.action_cuda_graph_manager.enabled is True + assert policy.model.depth_decode_cuda_graph_manager.enabled is True + + policy.config.enable_inference_cuda_graph = False + policy.eval() + assert policy.model.model.action_cuda_graph_manager.enabled is False + assert policy.model.depth_decode_cuda_graph_manager.enabled is False + + +def test_lora_action_expert_target_is_opt_in(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + lora_rank=64, + lora_alpha=16, + lora_dropout=0.05, + lora_bias="none", + enable_lora_action_expert=False, + ) + + targets = policy._get_default_peft_targets()["target_modules"] + + assert "transformer|vision_backbone" in targets + assert "action_expert" not in targets + + policy.config.enable_lora_action_expert = True + targets = policy._get_default_peft_targets()["target_modules"] + + assert "action_expert" in targets + assert "state_encoder" not in targets + assert "state_norm" not in targets + assert "kv_proj" not in targets + + +def test_enable_lora_vlm_wraps_loaded_hf_model_locally(): + pytest.importorskip("peft") + + class DummyInnerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.transformer = torch.nn.Module() + self.transformer.wq = torch.nn.Linear(2, 2) + self.action_expert = torch.nn.Module() + self.action_expert.action_embed = torch.nn.Linear(2, 2) + + class DummyHFModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = {} + self.model = DummyInnerModel() + + def forward(self, x): + return self.model.transformer.wq(x) + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + checkpoint_path="/tmp/base", + lora_rank=2, + lora_alpha=4, + lora_dropout=0.0, + lora_bias="none", + enable_lora_action_expert=False, + train_action_expert_only=False, + enable_inference_cuda_graph=False, + ) + policy.model = DummyHFModel() + + policy._apply_lora_adapters() + + assert policy._backbone() is policy.model.base_model.model.model + trainable = [name for name, param in policy.named_parameters() if param.requires_grad] + assert trainable + assert any("lora_" in name for name in trainable) + assert any("action_expert.action_embed" in name and "lora_" not in name for name in trainable) + assert policy.model(torch.ones(1, 2)).shape == (1, 2) + + +def test_lora_vlm_unfreezes_action_expert_base_weights(): + class DummyInnerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.transformer = torch.nn.Module() + self.transformer.wq = torch.nn.Linear(2, 2) + self.action_expert = torch.nn.Module() + self.action_expert.action_embed = torch.nn.Linear(2, 2) + + class DummyHFModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = DummyInnerModel() + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.model = DummyHFModel() + + for param in policy.parameters(): + param.requires_grad_(False) + policy._unfreeze_action_expert_parameters() + + trainable = [name for name, param in policy.named_parameters() if param.requires_grad] + assert trainable + assert all("action_expert" in name for name in trainable) + + +def test_train_action_expert_only_requires_continuous_action_mode(): + with pytest.raises(ValueError, match="requires action_mode='continuous'"): + MolmoAct2Config(action_mode="both", train_action_expert_only=True) + + with pytest.raises(ValueError, match="incompatible with enable_lora_vlm"): + MolmoAct2Config(action_mode="continuous", train_action_expert_only=True, enable_lora_vlm=True) + + cfg = MolmoAct2Config(action_mode="continuous", train_action_expert_only=True) + assert cfg.train_action_expert_only + + +def test_molmoact2_sequence_length_is_inferred_from_fixed_token_budget(): + cfg = MolmoAct2Config( + action_mode="both", + chunk_size=10, + n_action_steps=10, + image_keys=["observation.images.image", "observation.images.wrist_image"], + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))}, + ) + + assert cfg.max_sequence_length is None + assert cfg.inferred_max_sequence_length() == 640 + assert cfg.inferred_max_sequence_length(include_discrete_action=False) == 576 + assert ( + infer_molmoact2_max_sequence_length( + num_images=2, + state_dim=8, + action_dim=7, + action_horizon=30, + include_discrete_action=True, + ) + == 768 + ) + + +def test_molmoact2_sequence_length_override_is_preserved(): + cfg = MolmoAct2Config(max_sequence_length=1024) + + assert cfg.inferred_max_sequence_length(num_images=2, state_dim=8, action_dim=7) == 1024 + + +def test_train_action_expert_only_freezes_non_action_expert_params(): + class DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.transformer = torch.nn.Linear(2, 2) + self.vision_backbone = torch.nn.Linear(2, 2) + self.action_expert = torch.nn.Linear(2, 2) + + def _require_action_expert(self): + return self.action_expert + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = DummyBackbone() + self.lm_head = torch.nn.Linear(2, 2) + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(train_action_expert_only=True) + policy.model = DummyModel() + + policy._freeze_non_action_expert_parameters() + policy.train() + + assert policy.model.model.action_expert.training + assert not policy.model.training + assert not policy.model.model.transformer.training + assert all(param.requires_grad for param in policy.model.model.action_expert.parameters()) + assert not any(param.requires_grad for param in policy.model.model.transformer.parameters()) + assert not any(param.requires_grad for param in policy.model.model.vision_backbone.parameters()) + assert not any(param.requires_grad for param in policy.model.lm_head.parameters()) + + +def test_load_hf_model_accepts_max_action_horizon_schema(monkeypatch): + class DummyLoadedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + max_action_dim=32, + max_action_horizon=30, + action_mode="both", + add_action_expert=True, + ) + self.model = torch.nn.Module() + self.embed_tokens = torch.nn.Embedding(4, 4) + self.lm_head = torch.nn.Linear(4, 4, bias=False) + + def get_input_embeddings(self): + return self.embed_tokens + + loaded_model = DummyLoadedModel() + resolved_kwargs = {} + + def fake_resolve_checkpoint_location(checkpoint_path, **kwargs): + resolved_kwargs.update(kwargs) + return checkpoint_path + + monkeypatch.setattr(molmoact2_modeling, "_resolve_checkpoint_location", fake_resolve_checkpoint_location) + monkeypatch.setattr(molmoact2_modeling, "_patch_batched_image_attention_bias", lambda backbone: None) + monkeypatch.setattr(molmoact2_modeling, "_patch_leaf_safe_input_embedding_update", lambda backbone: None) + monkeypatch.setattr(molmoact2_modeling, "_patch_training_kv_collection", lambda backbone: None) + + from transformers import AutoModelForImageTextToText + + monkeypatch.setattr( + AutoModelForImageTextToText, + "from_pretrained", + lambda *args, **kwargs: loaded_model, + ) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + checkpoint_path="/tmp/new-schema-checkpoint", + checkpoint_revision="main", + checkpoint_force_download=True, + chunk_size=10, + n_action_steps=10, + action_mode="both", + ) + + policy._load_hf_model() + + assert policy.model is loaded_model + assert not hasattr(policy.model.config, "action_horizon") + assert policy.model.config.max_action_horizon == 10 + assert policy._generation_action_horizon() == 10 + assert resolved_kwargs == {"revision": "main", "force_download": True} + + +def test_load_hf_model_chunk_size_overrides_larger_than_checkpoint_horizon(monkeypatch): + class DummyLoadedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + max_action_dim=32, + max_action_horizon=10, + action_mode="both", + add_action_expert=True, + ) + self.model = torch.nn.Module() + self.embed_tokens = torch.nn.Embedding(4, 4) + self.lm_head = torch.nn.Linear(4, 4, bias=False) + + def get_input_embeddings(self): + return self.embed_tokens + + loaded_model = DummyLoadedModel() + monkeypatch.setattr( + molmoact2_modeling, + "_resolve_checkpoint_location", + lambda checkpoint_path, **kwargs: checkpoint_path, + ) + monkeypatch.setattr(molmoact2_modeling, "_patch_batched_image_attention_bias", lambda backbone: None) + monkeypatch.setattr(molmoact2_modeling, "_patch_leaf_safe_input_embedding_update", lambda backbone: None) + monkeypatch.setattr(molmoact2_modeling, "_patch_training_kv_collection", lambda backbone: None) + + from transformers import AutoModelForImageTextToText + + monkeypatch.setattr( + AutoModelForImageTextToText, + "from_pretrained", + lambda *args, **kwargs: loaded_model, + ) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + checkpoint_path="/tmp/new-schema-checkpoint", + chunk_size=30, + n_action_steps=30, + action_mode="both", + ) + + policy._load_hf_model() + + assert policy.model.config.max_action_horizon == 30 + assert policy._generation_action_horizon() == 30 + + +def test_load_hf_model_rejects_legacy_action_horizon_schema(monkeypatch): + class DummyLoadedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + max_action_dim=32, + action_horizon=30, + action_mode="both", + add_action_expert=True, + ) + self.model = torch.nn.Module() + + monkeypatch.setattr( + molmoact2_modeling, + "_resolve_checkpoint_location", + lambda checkpoint_path, **kwargs: checkpoint_path, + ) + + from transformers import AutoModelForImageTextToText + + monkeypatch.setattr( + AutoModelForImageTextToText, + "from_pretrained", + lambda *args, **kwargs: DummyLoadedModel(), + ) + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = MolmoAct2Config( + checkpoint_path="/tmp/legacy-schema-checkpoint", + chunk_size=10, + n_action_steps=10, + action_mode="both", + ) + + with pytest.raises(ValueError, match="max_action_horizon"): + policy._load_hf_model() + + +def test_rtc_processor_initialization_and_select_action_guard(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(rtc_config=RTCConfig(enabled=True)) + + policy.init_rtc_processor() + + assert policy.rtc_processor is not None + with pytest.raises(AssertionError, match="RTC is not supported for select_action"): + policy.select_action({}) + + +def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(action_mode="both", inference_action_mode=None) + policy._checkpoint_action_mode = None + + with pytest.raises(ValueError, match="inference_action_mode.*explicitly"): + policy._resolve_inference_action_mode(None) + with pytest.raises(TypeError, match="unexpected keyword argument 'action_mode'"): + policy.predict_action_chunk({}, action_mode="continuous") + + +def test_rtc_generation_uses_previous_chunk_prefix(): + class DummyActionExpert(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + + def prepare_context(self, **kwargs): + del kwargs + return SimpleNamespace() + + def get_or_prepare_modulation_cache(self, timesteps, *, cache_key=None): + del cache_key + return [SimpleNamespace(conditioning=timestep) for timestep in timesteps] + + def forward_with_context(self, actions, timesteps, *, context, modulation=None): + del timesteps, context, modulation + return torch.ones_like(actions) * self.weight + + class DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + flow_matching_num_steps=2, + max_action_horizon=4, + max_action_dim=3, + ) + self.action_expert = DummyActionExpert() + self.batch_size = 1 + + def _require_action_expert(self): + return self.action_expert + + def forward(self, **kwargs): + self.batch_size = int(kwargs["input_ids"].shape[0]) + return SimpleNamespace(past_key_values=object()) + + def _extract_kv_states(self, past_key_values): + del past_key_values + kv = torch.zeros(self.batch_size, 1, 1) + return [(kv, kv)] + + def _get_encoder_attention_mask(self, input_ids, attention_mask): + del input_ids + return attention_mask + + def _depth_gate_from_condition(self, **kwargs): + del kwargs + return None, None + + def _apply_depth_gate_to_layer_kv_states(self, encoder_kv_states, depth_mask, depth_gate): + del depth_mask, depth_gate + return encoder_kv_states + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + mask_action_dim_padding=True, + rtc_config=RTCConfig(enabled=True, execution_horizon=2, max_guidance_weight=1.0), + ) + policy.rtc_processor = None + policy.model = torch.nn.Module() + policy.model.model = DummyBackbone() + policy.init_rtc_processor() + model_inputs = { + "input_ids": torch.ones(1, 2, dtype=torch.long), + "attention_mask": torch.ones(1, 2, dtype=torch.long), + } + action_dim_is_pad = torch.tensor([[False, False, False]]) + + without_prefix = policy._generate_actions_from_inputs_with_rtc( + model_inputs=model_inputs, + action_dim_is_pad=action_dim_is_pad, + num_steps=2, + generator=torch.Generator().manual_seed(0), + inference_delay=0, + prev_chunk_left_over=None, + execution_horizon=None, + ) + with_prefix = policy._generate_actions_from_inputs_with_rtc( + model_inputs=model_inputs, + action_dim_is_pad=action_dim_is_pad, + num_steps=2, + generator=torch.Generator().manual_seed(0), + inference_delay=0, + prev_chunk_left_over=torch.zeros(1, 4, 3), + execution_horizon=None, + ) + + assert without_prefix.shape == (1, 4, 3) + assert not torch.allclose(without_prefix, with_prefix) + + +def test_discrete_state_string_matches_molmoact2_bins(): + state = np.asarray([-1.0, 0.0, 1.0, np.nan, np.inf, -np.inf], dtype=np.float32) + + assert _build_discrete_state_string(state, 256) == ( + "" + ) + + +def test_question_normalization_matches_release_prompt_style(): + assert _normalize_question_text("Instruction: Pick up the cube, please!") == "pick up the cube, please" + assert ( + _normalize_question_text("The task is to open drawer. Then close it.") == "open drawer; then close it" + ) + + +def test_action_padding_marks_only_real_dimensions(): + step = object.__new__(MolmoAct2PackInputsProcessorStep) + step.max_action_dim = 32 + action = torch.ones(2, 3, 7) + + padded, horizon_is_pad, dim_is_pad = step._pad_action(action, None) + + assert padded.shape == (2, 3, 32) + assert torch.equal(padded[..., :7], action) + assert torch.count_nonzero(padded[..., 7:]) == 0 + assert not horizon_is_pad.any() + assert not dim_is_pad[:, :7].any() + assert dim_is_pad[:, 7:].all() + + +def test_action_dim_padding_loss_reduces_like_old_trainer(): + loss = torch.arange(2 * 2 * 3 * 4, dtype=torch.float32).reshape(2, 2, 3, 4) + action_dim_is_pad = torch.tensor( + [ + [False, False, True, True], + [False, True, True, True], + ] + ) + + reduced = MolmoAct2Policy._apply_action_dim_padding_mask(loss, action_dim_is_pad) + + expected = torch.stack( + [ + loss[0, :, :, :2].sum(dim=-1) / 2, + loss[1, :, :, :1].sum(dim=-1) / 1, + ], + dim=0, + ) + assert torch.equal(reduced, expected) + + +def test_action_chunk_padding_keeps_old_mean_denominator(): + loss = torch.ones(1, 2, 4, 3) + action_horizon_is_pad = torch.tensor([[False, False, True, True]]) + + masked = MolmoAct2Policy._apply_action_chunk_padding_mask(loss, action_horizon_is_pad) + + assert masked.mean().item() == 0.5 + + +def test_selected_discrete_loss_matches_full_causal_lm_loss(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=False, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="none", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(2, 4, 3) + labels = torch.tensor( + [ + [-100, 1, 2, -100], + [-100, -100, 3, 4], + ] + ) + + selected_loss, z_loss = policy._discrete_loss_from_backbone_outputs({"labels": labels}, outputs) + + logits = policy.model.lm_head(outputs.last_hidden_state) + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + expected_loss = F.cross_entropy(logits.float().view(-1, 5), shift_labels.view(-1), ignore_index=-100) + assert torch.allclose(selected_loss, expected_loss) + assert z_loss is None + + +def test_discrete_z_loss_matches_old_trainer_formula(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="none", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(2, 4, 3) + labels = torch.tensor( + [ + [-100, 1, 2, -100], + [-100, -100, 3, 4], + ] + ) + + ce_loss, z_loss = policy._discrete_loss_from_backbone_outputs({"labels": labels}, outputs) + + logits = policy.model.lm_head(outputs.last_hidden_state).float() + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + valid = shift_labels != -100 + expected_ce = F.cross_entropy(logits.view(-1, 5), shift_labels.view(-1), ignore_index=-100) + expected_z = 1e-4 * logits.logsumexp(dim=-1)[valid].pow(2).mean() + assert torch.allclose(ce_loss, expected_ce) + assert z_loss is not None + assert torch.allclose(z_loss, expected_z) + + +def test_discrete_reduction_none_preserves_mean_loss(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="root_subsegments_root_tokens", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(3, 5, 3) + labels = torch.tensor( + [ + [-100, 1, -100, -100, -100], + [-100, -100, 2, 3, -100], + [-100, 4, 3, 2, 1], + ] + ) + + ce_mean, z_mean = policy._discrete_loss_from_backbone_outputs( + {"labels": labels}, + outputs, + reduction="mean", + ) + ce_none, z_none = policy._discrete_loss_from_backbone_outputs( + {"labels": labels}, + outputs, + reduction="none", + ) + + assert ce_none.shape == (3,) + assert z_none is not None + assert z_none.shape == (3,) + assert torch.allclose(ce_none.mean(), ce_mean) + assert torch.allclose(z_none.mean(), z_mean) + + +def test_forward_reduction_none_returns_per_sample_discrete_loss(): + class DummyBackbone(torch.nn.Module): + def __init__(self, hidden_states): + super().__init__() + self.hidden_states = hidden_states + + def forward(self, **kwargs): + del kwargs + return SimpleNamespace(last_hidden_state=self.hidden_states) + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + action_mode="discrete", + inference_action_mode="discrete", + model_dtype="float32", + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="none", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + hidden_states = torch.randn(2, 4, 3) + policy._backbone = lambda: DummyBackbone(hidden_states) + batch = { + "input_ids": torch.ones(2, 4, dtype=torch.long), + "labels": torch.tensor( + [ + [-100, 1, 2, -100], + [-100, -100, 3, 4], + ] + ), + } + + loss_none, metrics_none = policy.forward(batch, reduction="none") + loss_mean, metrics_mean = policy.forward(batch, reduction="mean") + + assert loss_none.shape == (2,) + assert torch.allclose(loss_none.mean(), loss_mean) + assert metrics_none["loss"] == pytest.approx(metrics_mean["loss"]) + + +def test_discrete_root_token_weighting_matches_old_loss_mask_scaling(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + softmax_auxiliary_loss=True, + softmax_auxiliary_loss_scale=1e-4, + discrete_loss_token_weighting="root_subsegments_root_tokens", + ) + policy.model = torch.nn.Module() + policy.model.lm_head = torch.nn.Linear(3, 5, bias=False) + outputs = type("Outputs", (), {})() + outputs.last_hidden_state = torch.randn(2, 4, 3) + labels = torch.tensor( + [ + [-100, -100, 1, -100], + [-100, 2, 3, 4], + ] + ) + + ce_loss, z_loss = policy._discrete_loss_from_backbone_outputs({"labels": labels}, outputs) + + logits = policy.model.lm_head(outputs.last_hidden_state).float() + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + valid = shift_labels != -100 + log_z = logits.logsumexp(dim=-1) + token_ce = log_z - logits.gather(dim=-1, index=shift_labels.clamp_min(0).unsqueeze(-1)).squeeze(-1) + weights = torch.zeros_like(token_ce) + counts = valid.sum(dim=1).float() + weights[valid] = (2.0 / torch.sqrt(counts))[:, None].expand_as(weights)[valid] + expected_ce = (token_ce * weights).sum() / weights.sum() + expected_z = 1e-4 * (log_z.pow(2) * weights).sum() / weights.sum() + assert torch.allclose(ce_loss, expected_ce) + assert z_loss is not None + assert torch.allclose(z_loss, expected_z) + + +class _DummyActionTokenizer: + def decode(self, tokens, *, time_horizon=None, action_dim=None): + decoded = [] + for token_row in tokens: + decoded.append(np.full((time_horizon, action_dim), sum(token_row), dtype=np.float32)) + return np.stack(decoded) + + +def test_discrete_decode_extracts_action_bins_for_each_batch(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(chunk_size=2) + policy.action_tokenizer = _DummyActionTokenizer() + policy.model = torch.nn.Module() + policy.model.config = SimpleNamespace( + action_start_token_id=10, + action_end_token_id=11, + action_token_start_id=100, + num_action_tokens=4, + action_horizon=2, + ) + + actions = policy._decode_discrete_action_chunk( + torch.tensor( + [ + [10, 100, 101, 11, 2], + [10, 102, 103, 11, 2], + ] + ), + action_dim=2, + ) + + assert actions.shape == (2, 2, 2) + assert torch.equal(actions[0], torch.ones(2, 2)) + assert torch.equal(actions[1], torch.full((2, 2), 5.0)) + + +def test_discrete_predict_action_chunk_uses_hf_cached_generation_path(): + class DummyOutput: + def __init__(self, token_id, batch_size): + logits = torch.full((batch_size, 1, 128), -1e9) + logits[:, :, token_id] = 1.0 + self.logits = logits + self.past_key_values = object() + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + self.config = SimpleNamespace( + action_start_token_id=10, + action_end_token_id=11, + action_token_start_id=100, + num_action_tokens=4, + action_horizon=2, + ) + self.tokens = [10, 100, 101, 11, 2] + self.index = 0 + + def forward(self, **kwargs): + batch_size = int(kwargs["input_ids"].shape[0]) + return DummyOutput(self.tokens[self.index], batch_size) + + def _consume_generation_tokens(self, token_ids, *, past_key_values, attention_mask): + del past_key_values + self.index += 1 + if attention_mask is not None: + attention_mask = torch.cat([attention_mask, torch.ones_like(token_ids[:, None])], dim=-1) + return DummyOutput(self.tokens[self.index], int(token_ids.shape[0])), attention_mask + + def _require_eos_token_id(self): + return 2 + + def _action_token_id_to_bin(self): + return {100: 0, 101: 1, 102: 2, 103: 3} + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + action_mode="discrete", + inference_action_mode="discrete", + model_dtype="float32", + output_features={ACTION: SimpleNamespace(shape=(2,))}, + discrete_generation_max_steps=None, + discrete_action_tokenizer="unused", + trust_remote_code=True, + chunk_size=2, + n_action_steps=1, + rtc_config=None, + ) + policy.model = DummyModel() + policy.action_tokenizer = _DummyActionTokenizer() + + actions = policy.predict_action_chunk( + { + "input_ids": torch.ones(1, 3, dtype=torch.long), + "attention_mask": torch.ones(1, 3, dtype=torch.long), + } + ) + + assert policy.model.index == 4 + assert actions.shape == (1, 1, 2) + assert torch.equal(actions, torch.ones(1, 1, 2)) + + +def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled(): + class DummyOutput: + def __init__(self, token_id, past_key_values): + logits = torch.full((1, 1, 128), -1e9) + logits[:, :, token_id] = 1.0 + self.logits = logits + self.past_key_values = past_key_values + + class DummyLmHead(torch.nn.Module): + def forward(self, hidden_states): + token_id = int(hidden_states[0, 0, 0].item()) + logits = torch.full((1, 1, 128), -1e9) + logits[:, :, token_id] = 1.0 + return logits + + class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + self.lm_head = DummyLmHead() + self.config = SimpleNamespace( + action_start_token_id=10, + action_end_token_id=11, + action_token_start_id=100, + num_action_tokens=4, + action_horizon=2, + ) + self.tokens = [10, 100, 101, 11, 2] + self.index = 0 + self.used_static_cache = False + self.graph_steps = 0 + + def forward(self, **kwargs): + self.used_static_cache = kwargs.get("past_key_values") == "static-cache" + return DummyOutput(self.tokens[self.index], kwargs.get("past_key_values")) + + def _make_ar_decode_static_cache(self, inputs, *, max_steps): + assert int(inputs["input_ids"].shape[1]) == 3 + assert max_steps == 32 + return "static-cache" + + def _make_depth_decode_attention_bias(self, inputs, past_key_values): + assert past_key_values == "static-cache" + return torch.ones(1, 1, 35, 35, dtype=torch.float32) + + def _run_ar_decode_step(self, token_ids, *, past_key_values, attention_bias): + assert past_key_values == "static-cache" + assert attention_bias.shape == (1, 1, 35, 35) + self.index += 1 + self.graph_steps += 1 + return torch.tensor([[[float(self.tokens[self.index])]]]), past_key_values + + def _require_eos_token_id(self): + return 2 + + def _action_token_id_to_bin(self): + return {100: 0, 101: 1, 102: 2, 103: 3} + + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace( + action_mode="discrete", + inference_action_mode="discrete", + model_dtype="float32", + output_features={ACTION: SimpleNamespace(shape=(2,))}, + discrete_generation_max_steps=None, + discrete_action_tokenizer="unused", + trust_remote_code=True, + chunk_size=2, + n_action_steps=1, + rtc_config=None, + enable_inference_cuda_graph=True, + ) + policy.model = DummyModel() + policy.action_tokenizer = _DummyActionTokenizer() + torch.nn.Module.train(policy, False) + + actions = policy.predict_action_chunk( + { + "input_ids": torch.ones(1, 3, dtype=torch.long), + "attention_mask": torch.ones(1, 3, dtype=torch.long), + } + ) + + assert policy.model.used_static_cache + assert policy.model.graph_steps == 4 + assert actions.shape == (1, 1, 2) + assert torch.equal(actions, torch.ones(1, 1, 2)) + + +class _DummyMolmoBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed = torch.nn.Embedding(5, 3) + + def get_input_embeddings(self): + return self.embed + + +class _DummyMolmoModel(torch.nn.Module): + def __init__(self, *, tie_lm_head: bool = False): + super().__init__() + self.model = _DummyMolmoBackbone() + self.lm_head = torch.nn.Linear(3, 5, bias=False) + if tie_lm_head: + self.lm_head.weight = self.model.embed.weight + + def get_input_embeddings(self): + return self.model.embed + + +def test_freeze_embedding_freezes_input_embeddings_only_when_untied(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.model = _DummyMolmoModel() + + policy._freeze_input_embeddings() + + assert not policy.model.model.embed.weight.requires_grad + assert policy.model.lm_head.weight.requires_grad + + +def test_freeze_embedding_rejects_tied_lm_head_without_mutating(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.model = _DummyMolmoModel(tie_lm_head=True) + + with pytest.raises(RuntimeError, match="would also freeze lm_head"): + policy._freeze_input_embeddings() + + assert policy.model.model.embed.weight.requires_grad diff --git a/uv.lock b/uv.lock index c5f026517..265265109 100644 --- a/uv.lock +++ b/uv.lock @@ -2915,6 +2915,10 @@ metaworld = [ { name = "scipy" }, { name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] +molmoact2 = [ + { name = "peft" }, + { name = "transformers" }, +] motorbridge-dep = [ { name = "motorbridge" }, ] @@ -3128,6 +3132,7 @@ requires-dist = [ { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" }, { name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["molmoact2"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["motorbridge-dep"], marker = "extra == 'rebot'" }, { name = "lerobot", extras = ["motorbridge-smart-servo-dep"], marker = "extra == 'rebot'" }, { name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" }, @@ -3135,6 +3140,7 @@ requires-dist = [ { name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["peft"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" }, + { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'molmoact2'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["phone"], marker = "extra == 'all'" }, @@ -3172,6 +3178,7 @@ requires-dist = [ { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" }, + { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" }, @@ -3244,7 +3251,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt"