From 8df8d3d8665802f824a2376cb78a0534469341d9 Mon Sep 17 00:00:00 2001 From: javadcc_mac Date: Sat, 9 May 2026 21:39:19 +0800 Subject: [PATCH] feat(policies): add EVO1 policy --- docs/source/_toctree.yml | 2 + docs/source/evo1.mdx | 132 +++++ pyproject.toml | 2 + src/lerobot/policies/__init__.py | 2 + src/lerobot/policies/evo1/__init__.py | 19 + .../policies/evo1/configuration_evo1.py | 211 ++++++++ src/lerobot/policies/evo1/evo1_model.py | 234 +++++++++ src/lerobot/policies/evo1/flow_matching.py | 456 ++++++++++++++++++ .../policies/evo1/internvl3_embedder.py | 366 ++++++++++++++ src/lerobot/policies/evo1/modeling_evo1.py | 419 ++++++++++++++++ src/lerobot/policies/evo1/processor_evo1.py | 106 ++++ src/lerobot/policies/factory.py | 18 +- tests/policies/evo1/test_evo1.py | 225 +++++++++ 13 files changed, 2190 insertions(+), 2 deletions(-) create mode 100644 docs/source/evo1.mdx create mode 100644 src/lerobot/policies/evo1/__init__.py create mode 100644 src/lerobot/policies/evo1/configuration_evo1.py create mode 100644 src/lerobot/policies/evo1/evo1_model.py create mode 100644 src/lerobot/policies/evo1/flow_matching.py create mode 100644 src/lerobot/policies/evo1/internvl3_embedder.py create mode 100644 src/lerobot/policies/evo1/modeling_evo1.py create mode 100644 src/lerobot/policies/evo1/processor_evo1.py create mode 100644 tests/policies/evo1/test_evo1.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 40cec863f..58a43d887 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -49,6 +49,8 @@ title: π₀.₅ (Pi05) - local: eo1 title: EO-1 + - local: evo1 + title: EVO1 - local: groot title: NVIDIA GR00T N1.5 - local: xvla diff --git a/docs/source/evo1.mdx b/docs/source/evo1.mdx new file mode 100644 index 000000000..a86f7a56b --- /dev/null +++ b/docs/source/evo1.mdx @@ -0,0 +1,132 @@ +# EVO1 + +EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs. + +## Model Overview + +The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again. + +### What the LeRobot Integration Covers + +- Standard `policy.type=evo1` configuration through LeRobot +- InternVL3 image/text embedding with optional FlashAttention fallback +- Stage-based finetuning controls for action-head-only and VLM finetuning runs +- Continuous flow-matching action prediction +- Checkpoint save/load through LeRobot policy APIs +- Training with `lerobot-train` and evaluation with standard policy inference APIs + +The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path. + +## Installation Requirements + +1. Install LeRobot by following the [Installation Guide](./installation). +2. Install EVO1 dependencies: + + ```bash + pip install -e ".[evo1]" + ``` + +3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available. + +EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory. + +## Data Requirements + +EVO1 expects a LeRobot dataset with: + +- One to `policy.max_views` visual observations, for example `observation.images.image` +- `observation.state` +- `action` +- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field` + +State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned. + +## Usage + +To use EVO1 in a LeRobot configuration, specify: + +```python +policy.type=evo1 +``` + +By default, a new EVO1 policy initializes its VLM from: + +```python +policy.vlm_model_name=OpenGVLab/InternVL3-1B +``` + +Once a LeRobot-format EVO1 checkpoint is available, load it with: + +```python +policy.path=your-org/your-evo1-checkpoint +``` + +## Training + +### Stage 1 + +Stage 1 freezes the VLM and trains the action head: + +```bash +lerobot-train \ + --dataset.repo_id=your_org/your_dataset \ + --policy.type=evo1 \ + --policy.training_stage=stage1 \ + --policy.vlm_model_name=OpenGVLab/InternVL3-1B \ + --policy.device=cuda \ + --policy.chunk_size=50 \ + --policy.n_action_steps=50 \ + --policy.max_state_dim=24 \ + --policy.max_action_dim=24 \ + --policy.optimizer_lr=1e-5 \ + --batch_size=4 \ + --steps=5000 \ + --output_dir=./outputs/evo1_stage1 +``` + +### Stage 2 + +Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint: + +```bash +lerobot-train \ + --dataset.repo_id=your_org/your_dataset \ + --policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \ + --policy.training_stage=stage2 \ + --policy.vlm_model_name=OpenGVLab/InternVL3-1B \ + --policy.device=cuda \ + --policy.chunk_size=50 \ + --policy.n_action_steps=50 \ + --policy.max_state_dim=24 \ + --policy.max_action_dim=24 \ + --policy.optimizer_lr=1e-5 \ + --batch_size=4 \ + --steps=80000 \ + --output_dir=./outputs/evo1_stage2 +``` + +### Key Training Parameters + +| Parameter | Default | Description | +| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- | +| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory | +| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches | +| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy | +| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype | +| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back | +| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules | +| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported | +| `policy.chunk_size` | `50` | Number of future actions predicted per chunk | +| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk | +| `policy.max_state_dim` | `24` | State padding dimension | +| `policy.max_action_dim` | `24` | Action padding dimension | +| `policy.task_field` | `task` | Batch field used as the language prompt | + +## References + +- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1) +- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B) + +## License + +This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data. diff --git a/pyproject.toml b/pyproject.toml index 0ae3abd73..05b254637 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,6 +195,7 @@ groot = [ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] +evo1 = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features @@ -334,6 +335,7 @@ ignore = [ # E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect "src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"] "src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original +"src/lerobot/policies/evo1/**" = ["N801", "N812"] [tool.ruff.lint.isort] combine-as-imports = true diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 2633d04ad..a4cf2e64f 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -17,6 +17,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .eo1.configuration_eo1 import EO1Config as EO1Config +from .evo1.configuration_evo1 import Evo1Config as Evo1Config from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors from .groot.configuration_groot import GrootConfig as GrootConfig from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig @@ -40,6 +41,7 @@ __all__ = [ # Configuration classes "ACTConfig", "DiffusionConfig", + "Evo1Config", "GrootConfig", "MultiTaskDiTConfig", "EO1Config", diff --git a/src/lerobot/policies/evo1/__init__.py b/src/lerobot/policies/evo1/__init__.py new file mode 100644 index 000000000..f15a27d8c --- /dev/null +++ b/src/lerobot/policies/evo1/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_evo1 import Evo1Config +from .modeling_evo1 import EVO1Policy +from .processor_evo1 import make_evo1_pre_post_processors + +__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"] diff --git a/src/lerobot/policies/evo1/configuration_evo1.py b/src/lerobot/policies/evo1/configuration_evo1.py new file mode 100644 index 000000000..4cfec4d28 --- /dev/null +++ b/src/lerobot/policies/evo1/configuration_evo1.py @@ -0,0 +1,211 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from dataclasses import dataclass, field + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + + +@LRSchedulerConfig.register_subclass("evo1_exact") +@dataclass +class Evo1SchedulerConfig(LRSchedulerConfig): + num_warmup_steps: int + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + def lr_lambda(current_step: int) -> float: + if current_step < self.num_warmup_steps: + return current_step / max(1, self.num_warmup_steps) + progress = (current_step - self.num_warmup_steps) / max( + 1, num_training_steps - self.num_warmup_steps + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + + return LambdaLR(optimizer, lr_lambda, -1) + + +@PreTrainedConfig.register_subclass("evo1") +@dataclass +class Evo1Config(PreTrainedConfig): + training_stage: str = "stage1" + use_amp: bool = True + + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + max_state_dim: int = 24 + max_action_dim: int = 24 + max_views: int = 3 + image_resolution: tuple[int, int] = (448, 448) + empty_cameras: int = 0 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + vlm_model_name: str = "OpenGVLab/InternVL3-1B" + vlm_num_layers: int | None = 14 + vlm_dtype: str = "bfloat16" + use_flash_attn: bool = True + action_head: str = "flowmatching" + embed_dim: int = 896 + hidden_dim: int = 1024 + state_hidden_dim: int = 1024 + num_heads: int = 8 + num_layers: int = 8 + dropout: float = 0.0 + num_inference_timesteps: int = 32 + num_categories: int = 1 + return_cls_only: bool = False + enable_gradient_checkpointing: bool = True + gradient_checkpointing_use_reentrant: bool = False + + finetune_vlm: bool | None = None + finetune_language_model: bool | None = None + finetune_vision_model: bool | None = None + finetune_action_head: bool | None = None + + task_field: str = "task" + embodiment_id_field: str | None = None + default_embodiment_id: int = 0 + + optimizer_lr: float = 1e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-5 + optimizer_grad_clip_norm: float = 1.0 + + scheduler_warmup_steps: int = 300 + drop_last: bool = True + + def __post_init__(self): + super().__post_init__() + if self.training_stage not in {"stage1", "stage2"}: + raise ValueError( + f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'" + ) + + if self.training_stage == "stage1": + if self.finetune_vlm is None: + self.finetune_vlm = False + if self.finetune_language_model is None: + self.finetune_language_model = False + if self.finetune_vision_model is None: + self.finetune_vision_model = False + if self.finetune_action_head is None: + self.finetune_action_head = True + elif self.training_stage == "stage2": + has_explicit_branch_flags = any( + flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model) + ) + if not has_explicit_branch_flags: + if self.finetune_vlm is None: + self.finetune_vlm = True + if self.finetune_language_model is None: + self.finetune_language_model = True + if self.finetune_vision_model is None: + self.finetune_vision_model = True + elif self.finetune_vlm is None: + self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model) + if self.finetune_action_head is None: + self.finetune_action_head = True + + if self.finetune_vlm is None: + self.finetune_vlm = False + if self.finetune_language_model is None: + self.finetune_language_model = False + if self.finetune_vision_model is None: + self.finetune_vision_model = False + if self.finetune_action_head is None: + self.finetune_action_head = False + + branch_vlm = self.finetune_language_model or self.finetune_vision_model + if self.finetune_vlm != branch_vlm: + raise ValueError( + "Inconsistent EVO1 finetune config: " + f"finetune_vlm={self.finetune_vlm} but " + f"(finetune_language_model or finetune_vision_model)={branch_vlm}. " + "When branch-level flags are used, finetune_vlm must match their effective union." + ) + + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})" + ) + + def validate_features(self) -> None: + if self.input_features is None: + self.input_features = {} + if self.output_features is None: + self.output_features = {} + + for i in range(self.empty_cameras): + key = OBS_IMAGES + f".empty_camera_{i}" + if key not in self.input_features: + self.input_features[key] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, *self.image_resolution), + ) + + if OBS_STATE not in self.input_features: + self.input_features[OBS_STATE] = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), + ) + + if ACTION not in self.output_features: + self.output_features[ACTION] = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return Evo1SchedulerConfig( + num_warmup_steps=self.scheduler_warmup_steps, + ) + + @property + def observation_delta_indices(self) -> list[int]: + return [0] + + @property + def action_delta_indices(self) -> list[int]: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/evo1/evo1_model.py b/src/lerobot/policies/evo1/evo1_model.py new file mode 100644 index 000000000..18e4dab33 --- /dev/null +++ b/src/lerobot/policies/evo1/evo1_model.py @@ -0,0 +1,234 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import torch +import torch.nn as nn +from PIL import Image + +from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead +from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder + + +def _cfgget(config: Any, key: str, default=None): + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + +class EVO1(nn.Module): + def __init__(self, config: dict): + super().__init__() + self.config = config + self._device = _cfgget(config, "device", "cuda") + self.return_cls_only = _cfgget(config, "return_cls_only", False) + vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B") + image_size = _cfgget(config, "image_size", 448) + if image_size is None: + image_resolution = _cfgget(config, "image_resolution", (448, 448)) + image_size = int(image_resolution[0]) + + self.embedder = InternVL3Embedder( + model_name=vlm_name, + image_size=image_size, + device=self._device, + num_language_layers=_cfgget(config, "vlm_num_layers", 14), + model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"), + use_flash_attn=_cfgget(config, "use_flash_attn", True), + enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True), + gradient_checkpointing_use_reentrant=_cfgget( + config, "gradient_checkpointing_use_reentrant", False + ), + ) + + action_head_type = _cfgget(config, "action_head", "flowmatching").lower() + if action_head_type != "flowmatching": + raise NotImplementedError(f"Unknown action_head: {action_head_type}") + + horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16)) + per_action_dim = _cfgget(config, "per_action_dim", 7) + action_dim = horizon * per_action_dim + + if isinstance(config, dict): + config["horizon"] = horizon + config["per_action_dim"] = per_action_dim + config["action_dim"] = action_dim + + self.horizon = horizon + self.per_action_dim = per_action_dim + self.action_head = FlowmatchingActionHead(config=config).to(self._device) + + def _normalize_image_batches( + self, + images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]], + prompt: str | list[str] | None, + image_mask: torch.Tensor, + ) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]: + if not images: + raise ValueError("EVO1 expects at least one image per sample.") + + first = images[0] + if isinstance(first, (Image.Image, torch.Tensor)): + image_batches = [list(images)] # type: ignore[arg-type] + else: + image_batches = [list(sample) for sample in images] # type: ignore[arg-type] + + batch_size = len(image_batches) + if prompt is None: + prompts = [""] * batch_size + elif isinstance(prompt, str): + prompts = [prompt] * batch_size + else: + prompts = [str(p) for p in prompt] + if len(prompts) != batch_size: + raise ValueError( + f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}" + ) + + if image_mask.dim() == 1: + image_mask = image_mask.unsqueeze(0) + if image_mask.shape[0] != batch_size: + raise ValueError( + f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}" + ) + + return image_batches, prompts, image_mask + + def get_vl_embeddings( + self, + images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]], + image_mask: torch.Tensor, + prompt: str | list[str] | None = None, + return_cls_only: bool | None = None, + ) -> torch.Tensor: + if return_cls_only is None: + return_cls_only = self.return_cls_only + + image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask) + return self.embedder.get_fused_image_text_embedding_from_tensor_images( + image_tensors_batch=image_batches, + image_masks=image_mask, + text_prompts=prompts, + return_cls_only=return_cls_only, + ) + + def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor: + if isinstance(state_input, list): + state_tensor = torch.tensor(state_input) + elif isinstance(state_input, torch.Tensor): + state_tensor = state_input + else: + raise TypeError(f"Unsupported state input type: {type(state_input)}") + + if state_tensor.ndim == 1: + state_tensor = state_tensor.unsqueeze(0) + + return state_tensor.to(self._device) + + def predict_action( + self, + fused_tokens: torch.Tensor, + state: torch.Tensor, + actions_gt: torch.Tensor | None = None, + action_mask: torch.Tensor | None = None, + embodiment_ids: torch.Tensor | None = None, + ): + if actions_gt is None: + return self.action_head.get_action( + fused_tokens, + state=state, + action_mask=action_mask, + embodiment_id=embodiment_ids, + ) + return self.action_head( + fused_tokens, + state=state, + actions_gt=actions_gt, + action_mask=action_mask, + embodiment_id=embodiment_ids, + ) + + @torch.no_grad() + def run_inference( + self, + images: list[Image.Image | torch.Tensor], + image_mask: torch.Tensor, + prompt: str, + state_input: list | torch.Tensor, + return_cls_only: bool | None = None, + action_mask: torch.Tensor | None = None, + embodiment_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + if image_mask.dim() == 1: + image_mask = image_mask.unsqueeze(0) + + fused_tokens = self.get_vl_embeddings( + images=[images], + image_mask=image_mask, + prompt=[prompt], + return_cls_only=return_cls_only, + ) + state_tensor = self.prepare_state(state_input) + action = self.predict_action( + fused_tokens, + state_tensor, + action_mask=action_mask, + embodiment_ids=embodiment_ids, + ) + if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16: + action = action.to(torch.float32) + return action + + def forward( + self, + fused_tokens: torch.Tensor, + state: torch.Tensor | None = None, + actions_gt: torch.Tensor | None = None, + action_mask: torch.Tensor | None = None, + embodiment_ids: torch.Tensor | None = None, + ): + return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids) + + def _set_module_trainable(self, module: nn.Module, trainable: bool): + for param in module.parameters(): + param.requires_grad = trainable + + def set_finetune_flags(self): + finetune_vlm = _cfgget(self.config, "finetune_vlm", False) + finetune_language_model = _cfgget(self.config, "finetune_language_model", False) + finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False) + has_explicit_branch_flags = any( + flag is not None for flag in (finetune_language_model, finetune_vision_model) + ) + finetune_language_model = bool(finetune_language_model) + finetune_vision_model = bool(finetune_vision_model) + finetune_vlm = bool(finetune_vlm) + + if has_explicit_branch_flags: + self._set_module_trainable(self.embedder, False) + if hasattr(self.embedder.model, "language_model"): + self._set_module_trainable(self.embedder.model.language_model, finetune_language_model) + if hasattr(self.embedder.model, "vision_model"): + self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model) + if hasattr(self.embedder.model, "mlp1"): + self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model) + elif not finetune_vlm: + self._set_module_trainable(self.embedder, False) + + if not _cfgget(self.config, "finetune_action_head", False): + self._set_module_trainable(self.action_head, False) diff --git a/src/lerobot/policies/evo1/flow_matching.py b/src/lerobot/policies/evo1/flow_matching.py new file mode 100644 index 000000000..b36af406c --- /dev/null +++ b/src/lerobot/policies/evo1/flow_matching.py @@ -0,0 +1,456 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import math +from types import SimpleNamespace + +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +def _cfgget(config, key: str, default=None): + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, dim: int, max_len: int = 1000): + super().__init__() + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, seq_len: int): + if seq_len > self.pe.size(1): + self._extend_pe(seq_len) + return self.pe[:, :seq_len, :] + + def _extend_pe(self, new_max_len): + old_max_len, dim = self.pe.size(1), self.pe.size(2) + if new_max_len <= old_max_len: + return + extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)) + extra_pe = torch.zeros(new_max_len - old_max_len, dim) + extra_pe[:, 0::2] = torch.sin(extra_positions * div_term) + extra_pe[:, 1::2] = torch.cos(extra_positions * div_term) + extra_pe = extra_pe.unsqueeze(0) + new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1) + self.pe = new_pe + + +class CategorySpecificLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1): + super().__init__() + self.num_categories = num_categories + if num_categories <= 1: + self.linear = nn.Linear(in_dim, out_dim) + else: + self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim)) + self.bias = nn.Parameter(torch.zeros(num_categories, out_dim)) + nn.init.xavier_uniform_(self.weight) + + def forward(self, x: torch.Tensor, category_id: torch.LongTensor): + if self.num_categories <= 1: + if x.dtype != self.linear.weight.dtype: + x = x.to(dtype=self.linear.weight.dtype) + return self.linear(x) + + if x.dtype != self.weight.dtype: + x = x.to(dtype=self.weight.dtype) + + orig_shape = x.shape + x_flat = x.reshape(-1, orig_shape[-1]) + if category_id.dim() == 0: + cid = category_id.item() + out = x_flat @ self.weight[cid] + self.bias[cid] + else: + category_id = category_id.reshape(-1) + if category_id.numel() != x_flat.size(0): + raise ValueError( + f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}" + ) + weight_selected = self.weight[category_id] + bias_selected = self.bias[category_id] + out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected + out_shape = orig_shape[:-1] + (out.shape[-1],) + return out.view(out_shape) + + +class CategorySpecificMLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1): + super().__init__() + self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories) + self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor, category_id: torch.LongTensor): + out = self.activation(self.fc1(x, category_id)) + out = self.fc2(out, category_id) + return out + + +class MultiEmbodimentActionEncoder(nn.Module): + def __init__( + self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1 + ): + super().__init__() + self.horizon = horizon + self.embed_dim = embed_dim + self.num_categories = num_categories + + self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories) + self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories) + self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories) + + self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon) + self.activation = nn.ReLU(inplace=True) + + def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor): + batch_size, horizon, action_dim = action_seq.shape + assert self.horizon == horizon, "Action sequence length must match horizon" + + x = action_seq.reshape(batch_size * horizon, action_dim) + if category_id.dim() == 0: + cat_ids = category_id.expand(horizon * batch_size) + else: + cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon) + + out = self.activation(self.W1(x, cat_ids)) + pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype) + out = out.view(batch_size, horizon, -1) + pos_enc + out = out.view(batch_size * horizon, -1) + out = self.activation(self.W2(out, cat_ids)) + out = self.W3(out, cat_ids) + return out.view(batch_size, horizon, self.embed_dim) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0): + super().__init__() + self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True) + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim)) + + def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor): + x = self.norm1(action_tokens) + attn_out, _ = self.attn(x, context_tokens, context_tokens) + x = action_tokens + attn_out + x2 = self.norm2(x) + if time_emb is not None: + x2 = x2 + time_emb.unsqueeze(1) + ff_out = self.ff(x2) + return x + ff_out + + +class FlowmatchingActionHead(nn.Module): + def __init__( + self, + config=None, + embed_dim: int = 896, + hidden_dim: int = 1024, + action_dim: int = 16 * 7, + horizon: int = 16, + per_action_dim: int = 7, + num_heads: int = 8, + num_layers: int = 8, + dropout: float = 0.0, + num_inference_timesteps: int = 20, + num_categories: int = 1, + ): + super().__init__() + + if config is not None: + embed_dim = _cfgget(config, "embed_dim", embed_dim) + hidden_dim = _cfgget(config, "hidden_dim", hidden_dim) + action_dim = _cfgget(config, "action_dim", action_dim) + horizon = _cfgget(config, "horizon", horizon) + per_action_dim = _cfgget(config, "per_action_dim", per_action_dim) + num_heads = _cfgget(config, "num_heads", num_heads) + num_layers = _cfgget(config, "num_layers", num_layers) + dropout = _cfgget(config, "dropout", dropout) + num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps) + num_categories = _cfgget(config, "num_categories", num_categories) + self.config = config + else: + self.config = SimpleNamespace( + embed_dim=embed_dim, + hidden_dim=hidden_dim, + action_dim=action_dim, + horizon=horizon, + per_action_dim=per_action_dim, + num_heads=num_heads, + num_layers=num_layers, + dropout=dropout, + num_inference_timesteps=num_inference_timesteps, + num_categories=num_categories, + ) + + logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps) + self.embed_dim = embed_dim + self.horizon = horizon + self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim) + self.action_dim = _cfgget(self.config, "action_dim", action_dim) + + self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + embed_dim=embed_dim, + num_heads=num_heads, + hidden_dim=embed_dim * 4, + dropout=dropout, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = nn.LayerNorm(embed_dim) + self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim) + self.mlp_head = CategorySpecificMLP( + input_dim=embed_dim, + hidden_dim=hidden_dim, + output_dim=action_dim, + num_categories=num_categories, + ) + + self.state_encoder = None + state_dim = _cfgget(self.config, "state_dim") + if state_dim is not None: + state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim) + self.state_encoder = CategorySpecificMLP( + input_dim=state_dim, + hidden_dim=state_hidden, + output_dim=embed_dim, + num_categories=num_categories, + ) + + if horizon > 1: + self.action_encoder = MultiEmbodimentActionEncoder( + action_dim=self.per_action_dim, + embed_dim=embed_dim, + hidden_dim=embed_dim, + horizon=horizon, + num_categories=num_categories, + ) + self.single_action_proj = None + else: + self.action_encoder = None + self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim) + + def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor: + if self.horizon > 1 and self.action_encoder is not None: + return self.action_encoder(action_seq, embodiment_id) + if self.single_action_proj is None: + raise RuntimeError("single_action_proj is not initialized for horizon <= 1.") + return self.single_action_proj(action_seq) + + def _expand_action_mask( + self, + action_mask: torch.Tensor, + batch_size: int, + per_action_dim: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + if action_mask is None: + raise ValueError("action_mask must be provided for flow matching inference.") + + if action_mask.dim() == 2: + expected_last_dim = self.horizon * per_action_dim + if action_mask.shape == (batch_size, expected_last_dim): + expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim) + elif action_mask.shape == (batch_size, per_action_dim): + expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim) + else: + raise ValueError( + f"Expected action_mask shape {(batch_size, expected_last_dim)} or " + f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}" + ) + elif action_mask.dim() == 3: + expected_shape = (batch_size, self.horizon, per_action_dim) + if tuple(action_mask.shape) != expected_shape: + raise ValueError( + f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}" + ) + expanded_mask = action_mask + else: + raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}") + + return expanded_mask.to(device=device, dtype=dtype) + + def forward( + self, + fused_tokens: torch.Tensor, + state: torch.Tensor = None, + actions_gt: torch.Tensor = None, + embodiment_id: torch.LongTensor = None, + state_mask: torch.Tensor = None, + action_mask: torch.Tensor = None, + ): + if actions_gt is None: + return self.get_action( + fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask + ) + + batch_size = fused_tokens.size(0) + device = fused_tokens.device + if embodiment_id is None: + embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device) + + context_tokens = fused_tokens + if state is not None and self.state_encoder is not None: + state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1) + context_tokens = torch.cat([context_tokens, state_emb], dim=1) + + t = ( + torch.distributions.Beta(2, 2) + .sample((batch_size,)) + .clamp(0.02, 0.98) + .to(device) + .to(dtype=self.dtype) + ) + time_index = (t * 999).long().clamp_(0, 999) + time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype) + + actions_gt_seq = actions_gt + noise = torch.rand_like(actions_gt) * 2 - 1 + if action_mask is not None: + action_mask = action_mask.to(dtype=noise.dtype, device=noise.device) + if action_mask.shape != noise.shape: + raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}") + actions_gt_seq = actions_gt_seq * action_mask + noise = noise * action_mask + + if self.horizon > 1: + noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim) + else: + noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1) + t_broadcast = t.view(batch_size, 1, 1) + action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq + + action_tokens = self._project_actions(action_intermediate_seq, embodiment_id) + target_dtype = self.dtype + action_tokens = action_tokens.to(dtype=target_dtype) + context_tokens = context_tokens.to(dtype=target_dtype) + time_emb = time_emb.to(dtype=target_dtype) + + x = action_tokens + for block in self.transformer_blocks: + x = block(x, context_tokens, time_emb) + x = self.norm_out(x) + + if self.horizon > 1: + x_flat = x.reshape(batch_size, -1) + x_pooled = self.seq_pool_proj(x_flat) + else: + x_pooled = x.squeeze(1) + + pred_velocity = self.mlp_head(x_pooled, embodiment_id) + return pred_velocity, noise + + def get_action( + self, + fused_tokens: torch.Tensor, + state: torch.Tensor = None, + embodiment_id: torch.LongTensor = None, + action_mask: torch.Tensor = None, + ): + batch_size = fused_tokens.size(0) + device = fused_tokens.device + if embodiment_id is None: + embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device) + + context_tokens = fused_tokens + if state is not None and self.state_encoder is not None: + state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1) + context_tokens = torch.cat([context_tokens, state_emb], dim=1) + + action_dim_total = _cfgget(self.config, "action_dim", self.action_dim) + per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1)) + + action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1 + action_seq = ( + action.view(batch_size, self.horizon, per_action_dim) + if self.horizon > 1 + else action.view(batch_size, 1, per_action_dim) + ) + action_mask = self._expand_action_mask( + action_mask, + batch_size=batch_size, + per_action_dim=per_action_dim, + device=action_seq.device, + dtype=action_seq.dtype, + ) + action_seq = action_seq * action_mask + + target_dtype = self.dtype + context_tokens = context_tokens.to(dtype=target_dtype) + + num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32)) + if num_steps <= 0: + raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}") + dt = 1.0 / num_steps + + for i in range(num_steps): + t = i / num_steps + time_index = min(int(t * 999), 999) + time_emb = ( + self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype) + ) + time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1) + + action_seq = action_seq * action_mask + action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype) + time_emb = time_emb.to(dtype=target_dtype) + + x = action_tokens + for block in self.transformer_blocks: + x = block(x, context_tokens, time_emb) + x = self.norm_out(x) + + if self.horizon > 1: + x_flat = x.reshape(batch_size, -1) + x_pooled = self.seq_pool_proj(x_flat) + else: + x_pooled = x.squeeze(1) + + pred = self.mlp_head(x_pooled, embodiment_id) + action = action + dt * pred + action_seq = ( + action.view(batch_size, self.horizon, per_action_dim) + if self.horizon > 1 + else action.view(batch_size, 1, per_action_dim) + ) + + action_seq = action_seq * action_mask + return action_seq.reshape(batch_size, -1) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py new file mode 100644 index 000000000..4cf1b00e0 --- /dev/null +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -0,0 +1,366 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools +import logging +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torchvision.transforms.functional as TF +from PIL import Image +from torchvision.transforms.functional import to_pil_image + +from lerobot.utils.import_utils import require_package + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +IMG_CONTEXT_TOKEN = "" # nosec B105 +IMG_START_TOKEN = "" # nosec B105 +IMG_END_TOKEN = "" # nosec B105 + +logger = logging.getLogger(__name__) + + +def flash_attn_is_available() -> bool: + try: + import flash_attn # noqa: F401 + except ModuleNotFoundError: + return False + return True + + +@functools.lru_cache(maxsize=10000) +def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int): + aspect_ratio = orig_width / orig_height + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = orig_width * orig_height + for ratio in target_ratios: + target_ar = ratio[0] / ratio[1] + diff = abs(aspect_ratio - target_ar) + if diff < best_ratio_diff: + best_ratio_diff = diff + best_ratio = ratio + elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num) + target_width = image_size * ratio_w + target_height = image_size * ratio_h + blocks = ratio_w * ratio_h + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + processed_images.append(resized_img.crop(box)) + if use_thumbnail and len(processed_images) != 1: + processed_images.append(image.resize((image_size, image_size))) + return processed_images + + +class InternVL3Embedder(nn.Module): + def __init__( + self, + model_name="OpenGVLab/InternVL3-1B", + image_size=448, + device="cuda", + num_language_layers: int | None = 14, + model_dtype: str | torch.dtype = "bfloat16", + use_flash_attn: bool = True, + enable_gradient_checkpointing: bool = True, + gradient_checkpointing_use_reentrant: bool = False, + ): + super().__init__() + self._requested_device = device + self.image_size = image_size + self.num_language_layers = num_language_layers + self.max_text_length = 1024 + self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing) + self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant) + + require_package("transformers", extra="evo1") + from transformers import AutoModel, AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + if isinstance(model_dtype, str): + try: + model_dtype = getattr(torch, model_dtype) + except AttributeError as exc: + raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc + + resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available()) + if use_flash_attn and not resolved_use_flash_attn: + logger.warning("flash_attn is not installed. Falling back to standard attention.") + + self.model = AutoModel.from_pretrained( + model_name, + torch_dtype=model_dtype, + trust_remote_code=True, + use_flash_attn=resolved_use_flash_attn, + low_cpu_mem_usage=True, + _fast_init=False, + ).to(self._requested_device) + + if hasattr(self.model.language_model, "model"): + layers = self.model.language_model.model.layers + else: + layers = self.model.language_model.layers + if self.num_language_layers is not None: + layers = layers[: self.num_language_layers] + + if hasattr(self.model.language_model, "model"): + self.model.language_model.model.layers = torch.nn.ModuleList(layers) + else: + self.model.language_model.layers = torch.nn.ModuleList(layers) + self.model.language_model.lm_head = torch.nn.Identity() + + self._configure_memory_features() + self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + + def _configure_memory_features(self) -> None: + checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant} + + if not self.enable_gradient_checkpointing: + if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"): + self.model.vision_model.encoder.gradient_checkpointing = False + language_model = getattr(self.model, "language_model", None) + if language_model is not None: + if hasattr(language_model, "gradient_checkpointing_disable"): + language_model.gradient_checkpointing_disable() + elif hasattr(language_model, "gradient_checkpointing"): + language_model.gradient_checkpointing = False + if hasattr(language_model, "model"): + inner = language_model.model + if hasattr(inner, "gradient_checkpointing_disable"): + inner.gradient_checkpointing_disable() + elif hasattr(inner, "gradient_checkpointing"): + inner.gradient_checkpointing = False + return + + def _enable_ckpt(module: nn.Module | None) -> bool: + if module is None: + return False + if hasattr(module, "gradient_checkpointing_enable"): + try: + module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs) + except TypeError: + module.gradient_checkpointing_enable() + return True + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = True + return True + return False + + enabled_any = _enable_ckpt(self.model) + + if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"): + self.model.vision_model.encoder.gradient_checkpointing = True + enabled_any = True + + language_model = getattr(self.model, "language_model", None) + if language_model is not None: + enabled_any = _enable_ckpt(language_model) or enabled_any + if hasattr(language_model, "model"): + enabled_any = _enable_ckpt(language_model.model) or enabled_any + if hasattr(language_model, "config"): + language_model.config.use_cache = False + + if hasattr(self.model, "config"): + self.model.config.use_cache = False + if hasattr(self.model, "enable_input_require_grads"): + self.model.enable_input_require_grads() + + if enabled_any: + logger.info("Gradient checkpointing enabled for InternVL3 embedder.") + else: + logger.warning( + "Requested gradient checkpointing, but model does not expose checkpointing controls." + ) + + def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor: + if isinstance(image, torch.Tensor): + pil_image = to_pil_image(image.detach().cpu()) + else: + pil_image = image.convert("RGB") + tiles = dynamic_preprocess(pil_image, image_size=self.image_size) + tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to( + device=self.device, dtype=torch.bfloat16 + ) + mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) + std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) + return (tile_tensors - mean) / std + + def _preprocess_images( + self, + image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]], + ) -> tuple[torch.Tensor, list[list[int]]]: + pixel_values_list = [] + batch_num_tiles_list: list[list[int]] = [] + + for image_tensors in image_tensors_batch: + num_tiles_list: list[int] = [] + for image in image_tensors: + tiles = self._preprocess_single_image(image) + pixel_values_list.append(tiles) + num_tiles_list.append(int(tiles.shape[0])) + batch_num_tiles_list.append(num_tiles_list) + + if pixel_values_list: + pixel_values = torch.cat(pixel_values_list, dim=0) + else: + pixel_values = torch.empty( + 0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device + ) + return pixel_values, batch_num_tiles_list + + def _build_multimodal_prompts( + self, + batch_num_tiles_list: list[list[int]], + text_prompts: Sequence[str], + ) -> list[str]: + prompts = [] + for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True): + prompt_segments = [] + for i, tile_count in enumerate(num_tiles_list): + token_count = self.model.num_image_token * tile_count + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN + prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n") + prompts.append("".join(prompt_segments) + text_prompt.strip()) + return prompts + + def _prepare_and_fuse_embeddings( + self, + prompts: Sequence[str], + vit_embeds: torch.Tensor, + image_masks: torch.Tensor, + batch_num_tiles_list: list[list[int]], + ) -> tuple[torch.Tensor, torch.Tensor]: + untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"] + true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0) + if true_sequence_length > self.max_text_length: + logger.warning( + "InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s", + self.max_text_length, + true_sequence_length, + ) + + model_inputs = self.tokenizer( + list(prompts), + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_text_length, + ).to(self.device) + input_ids = model_inputs["input_ids"] + attention_mask = model_inputs["attention_mask"] + + img_token_mask = input_ids == self.img_context_token_id + input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone() + + batch_size, _, channels = input_embeds.shape + vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device) + tokens_per_tile = self.model.num_image_token + actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist() + + vit_idx = 0 + for batch_index in range(batch_size): + expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile + mask_b = img_token_mask[batch_index] + actual_vis_tokens = actual_vis_tokens_list[batch_index] + + item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens] + vit_idx += expected_vis_tokens + if actual_vis_tokens > 0: + if item_vit_embeds.shape[0] < actual_vis_tokens: + raise ValueError( + f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: " + f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}" + ) + input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens] + + current_token_idx = 0 + img_token_locations = torch.where(mask_b)[0] + for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]): + num_tokens_for_image = num_tiles * tokens_per_tile + if not bool(image_masks[batch_index, image_index].item()): + start_offset = current_token_idx + end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations)) + if start_offset < end_offset: + idxs = img_token_locations[start_offset:end_offset] + attention_mask[batch_index, idxs] = 0 + current_token_idx += num_tokens_for_image + + return input_embeds, attention_mask + + def get_fused_image_text_embedding_from_tensor_images( + self, + image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]], + image_masks: torch.Tensor, + text_prompts: Sequence[str], + return_cls_only: bool = True, + ): + pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch) + if pixel_values.shape[0] == 0: + logger.warning("InternVL3 received an empty image batch after preprocessing.") + hidden_size = getattr(self.model.config, "hidden_size", None) + if hidden_size is None and hasattr(self.model.language_model, "config"): + hidden_size = getattr(self.model.language_model.config, "hidden_size", None) + if hidden_size is None: + raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.") + empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32) + return empty + + prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts) + vit_embeds = self.model.extract_feature(pixel_values) + inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings( + prompts, + vit_embeds, + image_masks.to(device=self.device), + batch_num_tiles_list, + ) + + outputs = self.model.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + return_dict=True, + ) + fused_hidden = outputs.hidden_states[-1].to(torch.float32) + return fused_hidden[:, 0, :] if return_cls_only else fused_hidden + + @property + def device(self) -> torch.device: + return next(self.model.parameters()).device diff --git a/src/lerobot/policies/evo1/modeling_evo1.py b/src/lerobot/policies/evo1/modeling_evo1.py new file mode 100644 index 000000000..474fd52a5 --- /dev/null +++ b/src/lerobot/policies/evo1/modeling_evo1.py @@ -0,0 +1,419 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import builtins +from collections import deque +from contextlib import nullcontext +from pathlib import Path + +import torch +from torch import Tensor + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.evo1.configuration_evo1 import Evo1Config +from lerobot.policies.evo1.evo1_model import EVO1 +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + + +class EVO1Policy(PreTrainedPolicy): + config_class = Evo1Config + name = "evo1" + + def __init__(self, config: Evo1Config, **kwargs): + super().__init__(config) + config.validate_features() + + if len(config.image_features) > config.max_views: + raise ValueError( + f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}" + ) + + self.config = config + self.model = EVO1(self._build_model_config(config)) + self.model.set_finetune_flags() + self.reset() + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool | None = None, + **kwargs, + ) -> T: + if strict is None: + strict = not (config is not None and getattr(config, "training_stage", None) == "stage2") + return super().from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + config=config, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + strict=strict, + **kwargs, + ) + + @staticmethod + def _build_model_config(config: Evo1Config) -> dict: + return { + "device": config.device, + "return_cls_only": config.return_cls_only, + "vlm_name": config.vlm_model_name, + "vlm_num_layers": config.vlm_num_layers, + "vlm_dtype": config.vlm_dtype, + "use_flash_attn": config.use_flash_attn, + "action_head": config.action_head, + "action_horizon": config.chunk_size, + "per_action_dim": config.max_action_dim, + "state_dim": config.max_state_dim, + "embed_dim": config.embed_dim, + "hidden_dim": config.hidden_dim, + "state_hidden_dim": config.state_hidden_dim, + "num_heads": config.num_heads, + "num_layers": config.num_layers, + "dropout": config.dropout, + "num_inference_timesteps": config.num_inference_timesteps, + "num_categories": config.num_categories, + "enable_gradient_checkpointing": config.enable_gradient_checkpointing, + "gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant, + "finetune_vlm": config.finetune_vlm, + "finetune_language_model": config.finetune_language_model, + "finetune_vision_model": config.finetune_vision_model, + "finetune_action_head": config.finetune_action_head, + } + + @property + def _camera_keys(self) -> list[str]: + return list(self.config.image_features) + + @property + def _env_action_dim(self) -> int: + action_feature = self.config.action_feature + if action_feature is None: + return self.config.max_action_dim + return int(action_feature.shape[0]) + + @property + def _compute_dtype(self) -> torch.dtype: + return next(self.model.action_head.parameters()).dtype + + @property + def _training_compute_dtype(self) -> torch.dtype: + if str(self.config.device).startswith("cuda"): + return torch.bfloat16 + return self._compute_dtype + + @property + def _inference_compute_dtype(self) -> torch.dtype: + if str(self.config.device).startswith("cuda") and self.config.use_amp: + return torch.bfloat16 + return self._compute_dtype + + def get_optim_params(self) -> list[dict]: + decay, no_decay = [], [] + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + is_bias = name.endswith("bias") or ".bias" in name + is_norm = param.dim() == 1 or "norm" in name.lower() + if is_bias or is_norm: + no_decay.append(param) + else: + decay.append(param) + return [ + {"params": decay, "weight_decay": self.config.optimizer_weight_decay}, + {"params": no_decay, "weight_decay": 0.0}, + ] + + def reset(self): + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]: + prompts = batch.get(self.config.task_field) + if prompts is None and self.config.task_field != "task": + prompts = batch.get("task") + if prompts is None: + raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.") + if isinstance(prompts, str): + return [prompts] + if isinstance(prompts, (list, tuple)): + return [str(prompt) for prompt in prompts] + raise TypeError(f"Unsupported prompt batch type: {type(prompts)}") + + def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + if OBS_STATE not in batch: + raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.") + state = batch[OBS_STATE] + if state.dim() == 1: + state = state.unsqueeze(0) + elif state.dim() == 3: + state = state[:, -1] + elif state.dim() != 2: + raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}") + batch_size, state_dim = state.shape + if state_dim > self.config.max_state_dim: + raise ValueError( + f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}" + ) + explicit_mask = batch.get("state_mask") + if explicit_mask is not None: + if explicit_mask.dim() == 1: + explicit_mask = explicit_mask.unsqueeze(0) + elif explicit_mask.dim() == 3: + explicit_mask = explicit_mask[:, -1] + elif explicit_mask.dim() != 2: + raise ValueError( + f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}" + ) + if explicit_mask.shape != (batch_size, state_dim): + raise ValueError( + f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}" + ) + padded = torch.zeros( + batch_size, + self.config.max_state_dim, + dtype=state.dtype, + device=self.config.device, + ) + padded[:, :state_dim] = state.to(device=self.config.device) + mask = torch.zeros( + batch_size, + self.config.max_state_dim, + dtype=torch.bool, + device=self.config.device, + ) + if explicit_mask is None: + mask[:, :state_dim] = True + else: + mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool) + return padded.to(dtype=self._compute_dtype), mask + + def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + if ACTION not in batch: + raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.") + action = batch[ACTION] + if action.dim() == 2: + action = action.unsqueeze(1) + batch_size, horizon, action_dim = action.shape + if horizon != self.config.chunk_size: + raise ValueError( + f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}" + ) + if action_dim > self.config.max_action_dim: + raise ValueError( + f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}" + ) + explicit_mask = batch.get("action_mask") + if explicit_mask is not None: + if explicit_mask.dim() == 2: + if horizon == 1: + explicit_mask = explicit_mask.unsqueeze(1) + else: + raise ValueError( + f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}" + ) + elif explicit_mask.dim() != 3: + raise ValueError( + f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}" + ) + if explicit_mask.shape != (batch_size, horizon, action_dim): + raise ValueError( + "action_mask shape " + f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}" + ) + padded = torch.zeros( + batch_size, + horizon, + self.config.max_action_dim, + dtype=action.dtype, + device=self.config.device, + ) + padded[:, :, :action_dim] = action.to(device=self.config.device) + mask = torch.zeros( + batch_size, + horizon, + self.config.max_action_dim, + dtype=torch.bool, + device=self.config.device, + ) + if explicit_mask is None: + mask[:, :, :action_dim] = True + else: + mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool) + return padded.to(dtype=self._compute_dtype), mask + + def _prepare_inference_action_mask(self, batch_size: int) -> Tensor: + mask = torch.zeros( + batch_size, + self.config.max_action_dim, + dtype=torch.bool, + device=self.config.device, + ) + mask[:, : self._env_action_dim] = True + return mask + + def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor: + embodiment_ids = batch.get("embodiment_id") + if embodiment_ids is None and self.config.embodiment_id_field: + embodiment_ids = batch.get(self.config.embodiment_id_field) + if embodiment_ids is None: + return torch.full( + (batch_size,), + self.config.default_embodiment_id, + dtype=torch.long, + device=self.config.device, + ) + if embodiment_ids.dim() == 0: + embodiment_ids = embodiment_ids.unsqueeze(0) + elif embodiment_ids.dim() > 1: + embodiment_ids = embodiment_ids[:, -1] + return embodiment_ids.to(device=self.config.device, dtype=torch.long) + + def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]: + camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}.")) + if not camera_keys: + raise ValueError("EVO1 requires at least one visual observation feature.") + batch_size = batch[camera_keys[0]].shape[0] + image_batches: list[list[Tensor]] = [] + image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool) + + for batch_index in range(batch_size): + sample_images: list[Tensor] = [] + for camera_key in camera_keys[: self.config.max_views]: + image = batch[camera_key] + if image.dim() == 3: + image = image.unsqueeze(0) + elif image.dim() == 5: + image = image[:, -1] + elif image.dim() != 4: + raise ValueError( + f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}" + ) + sample_images.append(image[batch_index].detach().cpu()) + if not sample_images: + raise ValueError("EVO1 received a batch without any image tensor.") + while len(sample_images) < self.config.max_views: + sample_images.append(torch.zeros_like(sample_images[0])) + image_batches.append(sample_images[: self.config.max_views]) + image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True + + return image_batches, image_masks + + def _compute_fused_tokens( + self, + prompts: list[str], + image_batches: list[list[Tensor]], + image_masks: Tensor, + ) -> Tensor: + fused_tokens = self.model.get_vl_embeddings( + images=image_batches, + image_mask=image_masks, + prompt=prompts, + return_cls_only=self.config.return_cls_only, + ) + return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype) + + def _compute_masked_loss( + self, + pred_velocity: Tensor, + target_velocity: Tensor, + action_mask: Tensor, + reduction: str, + ) -> Tensor: + flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype) + sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2) + active = flat_mask.sum(dim=1).clamp_min(1.0) + per_sample_loss = sq_error.sum(dim=1) / active + if reduction == "none": + return per_sample_loss + if reduction != "mean": + raise ValueError(f"Unsupported reduction '{reduction}'") + return sq_error.sum() / active.sum() + + def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: + prompts = self._normalize_task_batch(batch) + image_batches, image_masks = self._collect_image_batches(batch) + states, _state_mask = self._prepare_state(batch) + actions_gt, action_mask = self._prepare_actions(batch) + fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks) + states = states.to(dtype=self._training_compute_dtype) + actions_gt = actions_gt.to(dtype=self._training_compute_dtype) + fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype) + embodiment_ids = self._get_embodiment_ids(batch, states.shape[0]) + + pred_velocity, noise = self.model( + fused_tokens, + state=states, + actions_gt=actions_gt, + action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype), + embodiment_ids=embodiment_ids, + ) + flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype) + target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask + loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction) + loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item() + return loss, { + "loss": loss_mean, + "active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()), + } + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + self.eval() + + prompts = self._normalize_task_batch(batch) + image_batches, image_masks = self._collect_image_batches(batch) + states, _state_mask = self._prepare_state(batch) + fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks) + states = states.to(dtype=self._inference_compute_dtype) + fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype) + embodiment_ids = self._get_embodiment_ids(batch, states.shape[0]) + action_mask = self._prepare_inference_action_mask(states.shape[0]) + + with ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if self.config.use_amp and str(self.config.device).startswith("cuda") + else nullcontext() + ): + actions = self.model( + fused_tokens, + state=states, + action_mask=action_mask, + embodiment_ids=embodiment_ids, + ) + actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim) + return actions[:, :, : self._env_action_dim] + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + self.eval() + if len(self._action_queue) == 0: + action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + self._action_queue.extend(action_chunk.transpose(0, 1)) + return self._action_queue.popleft() diff --git a/src/lerobot/policies/evo1/processor_evo1.py b/src/lerobot/policies/evo1/processor_evo1.py new file mode 100644 index 000000000..f1a162df1 --- /dev/null +++ b/src/lerobot/policies/evo1/processor_evo1.py @@ -0,0 +1,106 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch + +from lerobot.policies.evo1.configuration_evo1 import Evo1Config +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import ( + batch_to_transition, + create_transition, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.utils.constants import ( + ACTION, + DONE, + INFO, + OBS_PREFIX, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, + REWARD, + TRUNCATED, +) + + +def evo1_batch_to_transition(batch: dict[str, Any]): + transition = batch_to_transition(batch) + complementary_data = dict(transition.get("complementary_data") or {}) + reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO} + for key, value in batch.items(): + if key in reserved or key.startswith(OBS_PREFIX): + continue + complementary_data.setdefault(key, value) + return create_transition( + observation=transition.get("observation"), + action=transition.get("action"), + reward=transition.get("reward", 0.0), + done=transition.get("done", False), + truncated=transition.get("truncated", False), + info=transition.get("info", {}), + complementary_data=complementary_data, + ) + + +def make_evo1_pre_post_processors( + config: Evo1Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device=config.device), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + to_transition=evo1_batch_to_transition, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 3609cc7c3..a511de67a 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features from .act.configuration_act import ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config +from .evo1.configuration_evo1 import Evo1Config from .groot.configuration_groot import GrootConfig from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config @@ -88,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x". + "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1". Returns: The policy class corresponding to the given name. @@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .eo1.modeling_eo1 import EO1Policy return EO1Policy + elif name == "evo1": + from .evo1.modeling_evo1 import EVO1Policy + + return EVO1Policy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -168,7 +173,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", - "smolvla", "wall_x". + "smolvla", "wall_x", "eo1", "evo1". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return WallXConfig(**kwargs) elif policy_type == "eo1": return EO1Config(**kwargs) + elif policy_type == "evo1": + return Evo1Config(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -413,6 +420,13 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, Evo1Config): + from .evo1.processor_evo1 import make_evo1_pre_post_processors + + processors = make_evo1_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) else: try: diff --git a/tests/policies/evo1/test_evo1.py b/tests/policies/evo1/test_evo1.py new file mode 100644 index 000000000..5bf170397 --- /dev/null +++ b/tests/policies/evo1/test_evo1.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch import nn + +import lerobot.policies.evo1.modeling_evo1 as modeling_evo1 +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.evo1.configuration_evo1 import Evo1Config +from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead +from lerobot.policies.factory import get_policy_class, make_policy_config +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + +STATE_DIM = 4 +ACTION_DIM = 3 +MAX_STATE_DIM = 6 +MAX_ACTION_DIM = 5 +CHUNK_SIZE = 2 +EMBED_DIM = 8 + + +class DummyEVO1(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.action_head = nn.Linear(1, 1) + self.get_vl_embeddings_calls = 0 + + def set_finetune_flags(self): + return None + + def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False): + self.get_vl_embeddings_calls += 1 + return torch.ones(len(images), 4, EMBED_DIM) + + def forward( + self, + fused_tokens, + state=None, + actions_gt=None, + action_mask=None, + embodiment_ids=None, + ): + batch_size = fused_tokens.shape[0] + if actions_gt is None: + return torch.ones(batch_size, CHUNK_SIZE * MAX_ACTION_DIM) + pred_velocity = torch.zeros(batch_size, CHUNK_SIZE * MAX_ACTION_DIM) + noise = torch.zeros_like(actions_gt) + return pred_velocity, noise + + +def make_config(training_stage="stage1", **kwargs): + config_kwargs = { + "device": "cpu", + "vlm_model_name": "dummy-internvl3", + "training_stage": training_stage, + "chunk_size": CHUNK_SIZE, + "n_action_steps": 1, + "max_state_dim": MAX_STATE_DIM, + "max_action_dim": MAX_ACTION_DIM, + "max_views": 2, + "embed_dim": EMBED_DIM, + "hidden_dim": 16, + "state_hidden_dim": 16, + "num_heads": 2, + "num_layers": 1, + "num_inference_timesteps": 2, + "input_features": { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)), + f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)), + }, + "output_features": { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)), + }, + } + config_kwargs.update(kwargs) + return Evo1Config(**config_kwargs) + + +def make_batch(include_action=True): + batch = { + "task": ["pick the block", "place the block"], + OBS_STATE: torch.randn(2, STATE_DIM), + f"{OBS_IMAGES}.front": torch.rand(2, 3, 16, 16), + } + if include_action: + batch[ACTION] = torch.randn(2, CHUNK_SIZE, ACTION_DIM) + return batch + + +def test_evo1_factory_registration(): + cfg = make_policy_config( + "evo1", + device="cpu", + vlm_model_name="dummy-internvl3", + input_features={ + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)), + f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)), + }, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}, + ) + + assert isinstance(cfg, Evo1Config) + assert get_policy_class("evo1") is modeling_evo1.EVO1Policy + + +def test_evo1_stage_defaults_and_consistency(): + stage1 = make_config(training_stage="stage1") + assert (stage1.finetune_vlm, stage1.finetune_language_model, stage1.finetune_vision_model) == ( + False, + False, + False, + ) + assert stage1.finetune_action_head is True + + stage2 = make_config(training_stage="stage2") + assert (stage2.finetune_vlm, stage2.finetune_language_model, stage2.finetune_vision_model) == ( + True, + True, + True, + ) + assert stage2.finetune_action_head is True + + explicit_off = make_config( + training_stage="stage2", + finetune_vlm=False, + finetune_language_model=False, + finetune_vision_model=False, + finetune_action_head=False, + ) + assert ( + explicit_off.finetune_vlm, + explicit_off.finetune_language_model, + explicit_off.finetune_vision_model, + ) == ( + False, + False, + False, + ) + assert explicit_off.finetune_action_head is False + + try: + make_config(training_stage="stage2", finetune_vlm=True, finetune_language_model=False) + except ValueError as exc: + assert "Inconsistent EVO1 finetune config" in str(exc) + else: + raise AssertionError("Expected inconsistent finetune config to raise ValueError") + + +def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch): + monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1) + policy = modeling_evo1.EVO1Policy(make_config()) + + loss, metrics = policy.forward(make_batch(include_action=True)) + assert loss.ndim == 0 + assert torch.isfinite(loss) + assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE + assert policy.model.get_vl_embeddings_calls == 1 + + action_chunk = policy.predict_action_chunk(make_batch(include_action=False)) + assert action_chunk.shape == (2, CHUNK_SIZE, ACTION_DIM) + + policy.reset() + selected = policy.select_action(make_batch(include_action=False)) + assert selected.shape == (2, ACTION_DIM) + + +def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch): + monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1) + config = make_config(chunk_size=1, n_action_steps=1) + policy = modeling_evo1.EVO1Policy(config) + batch = make_batch(include_action=True) + batch[ACTION] = torch.randn(2, ACTION_DIM) + batch["action_mask"] = torch.ones(2, ACTION_DIM, dtype=torch.bool) + + actions, action_mask = policy._prepare_actions(batch) + + assert actions.shape == (2, 1, MAX_ACTION_DIM) + assert action_mask.shape == (2, 1, MAX_ACTION_DIM) + assert action_mask[:, :, :ACTION_DIM].all() + assert not action_mask[:, :, ACTION_DIM:].any() + + +def test_flowmatching_dict_config_enables_state_encoder_for_horizon_one(): + head = FlowmatchingActionHead( + config={ + "embed_dim": EMBED_DIM, + "hidden_dim": 16, + "action_dim": ACTION_DIM, + "horizon": 1, + "per_action_dim": ACTION_DIM, + "num_heads": 2, + "num_layers": 1, + "num_inference_timesteps": 2, + "state_dim": STATE_DIM, + "state_hidden_dim": 16, + "num_categories": 1, + } + ) + + assert head.state_encoder is not None + pred_velocity, noise = head( + torch.randn(2, 4, EMBED_DIM), + state=torch.randn(2, STATE_DIM), + actions_gt=torch.randn(2, 1, ACTION_DIM), + action_mask=torch.ones(2, 1, ACTION_DIM, dtype=torch.bool), + ) + + assert pred_velocity.shape == (2, ACTION_DIM) + assert noise.shape == (2, 1, ACTION_DIM)