mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
* add molmoact2 policy * add apache headers to molmoact2 files * simplify molmoact2 package imports * align molmoact2 feature validation with eo pattern * remove molmoact2 processor override from factory * guard molmoact2 transformers imports * guard molmoact2 processor transformers import * add scipy dependency to molmoact2 extra * use a single molmoact2 action queue * move molmoact2 config logic into config * fix molmoact2 hf image key resolution * load molmoact2 without remote code * lazy import molmoact2 scipy * format molmoact2 files * skip molmoact2 tests without optional deps * fix molmoact2 pre-commit checks * validate molmoact2 gripper range
565 lines
22 KiB
Python
565 lines
22 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# ruff: noqa
|
|
|
|
"""Image processor class for MolmoAct2"""
|
|
|
|
from typing import Optional, Union
|
|
import numpy as np
|
|
import einops
|
|
import torch
|
|
import torchvision.transforms
|
|
|
|
from transformers.image_utils import (
|
|
IMAGENET_STANDARD_MEAN,
|
|
IMAGENET_STANDARD_STD,
|
|
ImageInput,
|
|
PILImageResampling,
|
|
make_flat_list_of_images,
|
|
valid_images,
|
|
to_numpy_array,
|
|
)
|
|
from transformers.image_transforms import convert_to_rgb
|
|
from transformers.processing_utils import ImagesKwargs
|
|
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
|
from transformers.utils import logging
|
|
from transformers.feature_extraction_utils import BatchFeature
|
|
from transformers.utils import TensorType, logging
|
|
|
|
|
|
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def normalize_image(
    image: np.ndarray,
    image_mean: list[float],
    image_std: list[float],
) -> np.ndarray:
    """Return ``(image - image_mean) / image_std`` without modifying ``image``.

    Args:
        image: Image whose last axis is the channel axis.
        image_mean: Per-channel mean, or a single float applied to every channel.
        image_std: Per-channel standard deviation, or a single float applied to every channel.

    Returns:
        A new normalized array; the input array is left untouched.
    """
    if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
        # Fast path for mean = std = 0.5, written as x * 2 - 1 with float32 scalars.
        return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)

    # Flatten to 1-D so a scalar or a per-channel list broadcasts over the last (channel) axis.
    mean = np.asarray(image_mean, dtype=np.float32).reshape(-1)
    std = np.asarray(image_std, dtype=np.float32).reshape(-1)
    # Out-of-place arithmetic: never mutates the caller's array and also accepts integer images.
    return (image - mean) / std
|
|
|
|
|
|
def resize_image(
    image: np.ndarray,
    desired_output_size: list[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an ``(H, W, C)`` image with torchvision and rescale it to float32 in ``[0, 1]``.

    Args:
        image: Image in height-width-channel layout. Floating-point images are clipped to
            ``[0, 1]`` after resizing; ``uint8`` images are clipped to ``[0, 255]``.
        desired_output_size: Target size forwarded to ``torchvision.transforms.Resize``.
        resample: Interpolation mode forwarded to ``torchvision.transforms.Resize``.

    Returns:
        The resized image as a float32 ``(H', W', C)`` array with values in ``[0, 1]``.

    Raises:
        TypeError: If ``image`` is neither floating point nor ``uint8``.
    """
    chw = torch.from_numpy(image).permute(2, 0, 1)
    source_dtype = chw.dtype

    # Explicit check instead of `assert`, so validation survives `python -O`.
    if torch.is_floating_point(chw):
        clip_min, clip_max = 0.0, 1.0
    elif source_dtype == torch.uint8:
        clip_min, clip_max = 0, 255
    else:
        raise TypeError("SigLIP expects float images or uint8 images, but got {}".format(source_dtype))

    # Resize in the source dtype, clip to its valid range and cast back before rescaling.
    resizer = torchvision.transforms.Resize(desired_output_size, resample, antialias=False)
    resized = torch.clip(resizer(chw), clip_min, clip_max).to(source_dtype)

    # Both ranges start at 0, so dividing by the upper bound maps them onto [0, 1].
    resized = resized.to(torch.float32) / float(clip_max)
    return resized.permute(1, 2, 0).numpy()
|
|
|
|
|
|
def select_tiling(h, w, patch_size, max_num_crops):
    """Pick the ``[rows, cols]`` crop grid that best fits an ``h`` x ``w`` pixel image.

    Every grid with ``rows * cols <= max_num_crops`` is a candidate. A candidate covers
    ``rows * patch_size`` by ``cols * patch_size`` pixels. The scale needed to fit the image
    into a candidate is the smaller of the two per-axis ratios.

    If every candidate needs downscaling, the one needing the least is returned. Otherwise the
    candidate with the smallest scale that is still ``>= 1`` (least upscaling) is returned.
    Candidates are ordered by crop count, then by row count, so ties favour smaller grids.

    Args:
        h: Image height in pixels. May be zero or negative when callers subtract margins
            from a small image.
        w: Image width in pixels (same caveat as ``h``).
        patch_size: Side length, in pixels, of the area one crop covers.
        max_num_crops: Upper bound on ``rows * cols``; must be at least 1.

    Returns:
        ``np.int32`` array ``[rows, cols]``.

    Raises:
        ValueError: If ``max_num_crops`` is smaller than 1.
    """
    if max_num_crops < 1:
        raise ValueError(f"max_num_crops must be at least 1, got {max_num_crops}.")

    tilings = sorted(
        (
            (rows, cols)
            for rows in range(1, max_num_crops + 1)
            for cols in range(1, max_num_crops + 1)
            if rows * cols <= max_num_crops
        ),
        # Sort so argmin and argmax favour smaller tilings in the event of a tie.
        key=lambda tiling: (tiling[0] * tiling[1], tiling[0]),
    )
    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_tilings, 2]
    candidate_resolutions = candidate_tilings * patch_size  # [n_tilings, 2], in pixels
    original_size = np.array([h, w], dtype=np.float32)  # [2]

    # A zero-length side gives an infinite scale on that axis. The fit is then driven by the
    # other side, or falls back to the smallest tiling.
    with np.errstate(divide="ignore"):
        per_axis_scale = candidate_resolutions.astype(np.float32) / original_size
    required_scale = np.min(per_axis_scale, axis=-1, keepdims=True)  # [n_tilings, 1]

    if np.all(required_scale < 1):
        # We are forced to downscale, so minimize the amount of downscaling.
        ix = np.argmax(required_scale)
    else:
        # Ignore downscaling candidates; pick the one needing the least upscaling.
        ix = np.argmin(np.where(required_scale < 1.0, 10e9, required_scale))
    return candidate_tilings[ix]
|
|
|
|
|
|
def build_resized_image(
    image: np.ndarray,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Produce the single global view of an image together with its patch-index grid.

    The whole image is resized to ``base_image_input_size`` and then normalized.

    Returns:
        A batch holding the one normalized view (a leading axis of size 1 is added to a 3-D
        result). Also a ``[H // image_patch_size, W // image_patch_size]`` array numbering the
        view's patches ``0, 1, 2, ...`` in row-major order, where ``H, W`` come from
        ``base_image_input_size``.
    """
    normalized = normalize_image(
        resize_image(image, base_image_input_size, resample),
        image_mean,
        image_std,
    )
    if normalized.ndim == 3:
        normalized = normalized[np.newaxis, ...]

    patches_high = base_image_input_size[0] // image_patch_size
    patches_wide = base_image_input_size[1] // image_patch_size
    patch_grid = np.arange(patches_high * patches_wide).reshape(patches_high, patches_wide)
    return normalized, patch_grid
|
|
|
|
|
|
def build_overlapping_crops(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Decompose an image into a set of overlapping square crops.

    The image is resized so that it tiles exactly into a grid of crop windows plus the overlap
    margins, normalized, and then cut into crops of size ``base_image_input_size``. Adjacent
    crops are spaced ``crop_window_size`` apart, so they overlap by the margins.

    Args:
        image: ``(H, W, C)`` image.
        max_crops: Maximum number of crops (forwarded to ``select_tiling``).
        overlap_margins: ``[left_margin, right_margin]`` measured in patches; either may be 0.
        base_image_input_size: ``[crop_size, crop_size]``; crops must be square.
        resample: Interpolation mode used for resizing.
        image_mean, image_std: Normalization statistics.
        image_patch_size: Side length in pixels of one ViT patch.

    :return crop_arr: [n_crops, crop_size, crop_size, C] The crops
    :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
        the crops were extracted from, what patch in `crop_arr` it corresponds to

    Raises:
        ValueError: If the crop size is not square or the margins leave no room for a crop window.
    """
    original_image_h, original_image_w = image.shape[:2]
    crop_size = base_image_input_size[0]
    if base_image_input_size[1] != crop_size:
        raise ValueError(f"Overlapping crops must be square, got base_image_input_size={base_image_input_size}.")

    left_margin, right_margin = overlap_margins
    crop_patches = crop_size // image_patch_size  # patches per crop side
    crop_window_patches = crop_patches - (left_margin + right_margin)  # usable patches per side
    if crop_window_patches <= 0:
        raise ValueError(
            f"overlap_margins={overlap_margins} leave no usable patches in a {crop_patches}-patch crop."
        )
    total_margin_pixels = image_patch_size * (left_margin + right_margin)  # pixels removed per dim
    crop_window_size = crop_window_patches * image_patch_size

    # Decide how to tile the image. To account for the overlap margins, compute the tiling as if
    # the image had no margins and the crop size excluded them.
    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src = resize_image(
        image,
        [
            tiling[0] * crop_window_size + total_margin_pixels,
            tiling[1] * crop_window_size + total_margin_pixels,
        ],
        resample,
    )
    src = normalize_image(src, image_mean, image_std)

    # Split the image into crops and record, per crop, which of its patches it is responsible for.
    n_crops = tiling[0] * tiling[1]
    patches_per_crop = crop_patches * crop_patches
    crop_arr = np.zeros([n_crops, crop_size, crop_size, src.shape[-1]], dtype=src.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patches, crop_patches], dtype=np.int32)
    # Explicit end index: `[-right_margin:]` would select the WHOLE axis when right_margin == 0.
    keep_end = crop_patches - right_margin
    for on_crop in range(n_crops):
        i, j = divmod(on_crop, tiling[1])
        # Slide over `src` by `crop_window_size` steps but extract `crop_size` crops, which
        # produces overlapping windows.
        y0 = i * crop_window_size
        x0 = j * crop_window_size
        crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]

        patch_idx = np.arange(patches_per_crop).reshape(crop_patches, crop_patches)
        patch_idx += on_crop * patches_per_crop

        # Mask patches in the overlap regions (interior edges only) so each patch of `src`
        # is claimed by exactly one crop.
        if i != 0:
            patch_idx[:left_margin, :] = -1
        if j != 0:
            patch_idx[:, :left_margin] = -1
        if i != tiling[0] - 1:
            patch_idx[keep_end:, :] = -1
        if j != tiling[1] - 1:
            patch_idx[:, keep_end:] = -1
        patch_idx_arr[on_crop] = patch_idx

    # `patch_idx_arr` is ordered crop-by-crop; transpose it into row-major order over `src`.
    patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patches, crop_patches])
    patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
    patch_idx_arr = np.reshape(patch_idx_arr, [-1])

    # Keep only the unmasked entries, mapping each patch of `src` to its patch in `crop_arr`.
    patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
        src.shape[0] // image_patch_size,
        src.shape[1] // image_patch_size,
    )
    return crop_arr, patch_idx_arr
|
|
|
|
|
|
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Split every image of a batch into flattened, non-overlapping ``patch_size`` squares.

    Accepts ``[n_images, h, w]`` or ``[n_images, h, w, c]`` input. Returns
    ``[n_images, n_patches, patch_size * patch_size * c]``, with ``c`` taken as 1 for 3-D input.
    Patches are ordered row-major, and each patch's pixels are flattened row-major with
    channels innermost. ``h`` and ``w`` must be multiples of ``patch_size``.
    """
    if array.ndim == 3:
        n_crops, height, width = array.shape
        channels = 1
    else:
        n_crops, height, width, channels = array.shape

    rows = height // patch_size
    cols = width // patch_size
    # Treat 3-D input as single-channel so both layouts share one reshape/transpose path.
    tiles = array.reshape(n_crops, rows, patch_size, cols, patch_size, channels)
    tiles = tiles.transpose(0, 1, 3, 2, 4, 5)
    return tiles.reshape(n_crops, rows * cols, patch_size * patch_size * channels)
|
|
|
|
|
|
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2-D grid of patch indices into ``pool_h`` x ``pool_w`` pooling windows.

    The grid is padded with -1 up to multiples of the pool size. Padding is split between the
    two sides, with any odd extra row or column going to the bottom or right. Each window is
    then flattened row-major.

    Returns:
        Array of shape ``[padded_h // pool_h, padded_w // pool_w, pool_h * pool_w]``.
    """
    rows, cols = idx_arr.shape[0], idx_arr.shape[1]
    padded_h = pool_h * ((rows + pool_h - 1) // pool_h)
    padded_w = pool_w * ((cols + pool_w - 1) // pool_w)
    extra_h = padded_h - rows
    extra_w = padded_w - cols

    padded = np.pad(
        idx_arr,
        ((extra_h // 2, extra_h - extra_h // 2), (extra_w // 2, extra_w - extra_w // 2)),
        mode="constant",
        constant_values=-1,
    )

    out_h = padded_h // pool_h
    out_w = padded_w // pool_w
    # (out_h * pool_h, out_w * pool_w) -> (out_h, pool_h, out_w, pool_w) -> (out_h, out_w, pool_h, pool_w)
    windows = padded.reshape(out_h, pool_h, out_w, pool_w).swapaxes(1, 2)
    return windows.reshape(out_h, out_w, pool_h * pool_w)
|
|
|
|
|
|
def image_to_patches_and_grids(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
    crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Turn one image into ViT patches plus the pooling layout of its image tokens.

    With ``crop_mode="resize"`` only one globally resized view is produced. With
    ``"overlap-and-resize-c2"`` or ``"overlap-and-resize"`` that global view is followed by
    overlapping high-resolution crops.

    Returns:
        image_grids: ``[1, 4]`` array ``[global_h, global_w, crops_h, crops_w]``. These are the
            pooled token grid sizes of the global view and of the crops; the crop entries are
            0 in ``"resize"`` mode.
        crops: The patchified views, global view first, as produced by ``batch_pixels_to_patches``.
        pooled_patch_idx: ``[n_tokens, image_pooling_h * image_pooling_w]``. For each token, the
            indices of the patches in ``crops`` to pool, masked with -1.

    Raises:
        ValueError: If ``crop_mode`` is not supported.
    """
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    pool_area = image_pooling_h * image_pooling_w
    global_patch_count = (base_image_input_size[0] // image_patch_size) * (
        base_image_input_size[1] // image_patch_size
    )

    def pool_layout(patch_grid):
        # Group patch indices into pooling windows and report the pooled grid size.
        windows = arange_for_pooling(patch_grid, image_pooling_h, image_pooling_w)
        grid_h, grid_w = windows.shape[:2]
        return windows.reshape([-1, pool_area]), grid_h, grid_w

    def global_view():
        return build_resized_image(
            image,
            base_image_input_size,
            resample,
            image_mean,
            image_std,
            image_patch_size,
        )

    if crop_mode == "resize":
        global_pixels, global_idx = global_view()
        global_pool, global_h, global_w = pool_layout(global_idx)
        grids = np.stack([np.array([global_h, global_w, 0, 0])], 0)
        return grids, batch_pixels_to_patches(global_pixels, image_patch_size), global_pool

    if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
        raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")

    crop_pixels, crop_patch_idx = build_overlapping_crops(
        image,
        max_crops,
        overlap_margins,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    crop_pool, crops_h, crops_w = pool_layout(crop_patch_idx)

    global_pixels, global_idx = global_view()
    all_pixels = np.concatenate([global_pixels, crop_pixels], 0)
    global_pool, global_h, global_w = pool_layout(global_idx)

    # The global view's patches come first in `all_pixels`, so shift every valid crop index past them.
    crop_pool = np.where(crop_pool >= 0, crop_pool + global_patch_count, -1)
    pooled = np.concatenate([global_pool, crop_pool])
    grids = np.stack([np.array([global_h, global_w, crops_h, crops_w])], 0)
    return grids, batch_pixels_to_patches(all_pixels, image_patch_size), pooled
|
|
|
|
|
|
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
    """Optional MolmoAct2-specific image keyword arguments, layered on top of `ImagesKwargs`.

    The keys mirror the parameters of the same names on `MolmoAct2ImageProcessor.preprocess`.
    """

    # Maximum number of overlapping high-resolution crops per image.
    max_crops: int | None
    # [left, right] overlap margins between neighboring crops, measured in patches.
    overlap_margins: list[int] | None
    # "resize", "overlap-and-resize-c2" or "overlap-and-resize".
    crop_mode: str | None
    # Spatial patch size (pixels) of the vision encoder.
    patch_size: int | None
    # [height, width] pooling window of the vision adapter.
    pooling_size: list[int] | None
|
|
|
|
|
|
class MolmoAct2ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MolmoAct2 image processor that preprocesses images for the model.

    Each image is converted into ViT patches for a globally resized view and, depending on
    `crop_mode`, a set of overlapping crops. The processor also returns the patch indices
    pooled into each image token.

    Args:
        size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
            Size of the image after resizing; also the size of each crop.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use when resizing the image.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use when normalizing the image. A float or one value per channel.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use when normalizing the image. A float or one value per channel.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        max_crops (`int`, *optional*, defaults to `8`):
            Maximum number of overlapping crops per image (the global view is not counted).
        overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
            `[left, right]` overlap margins between neighboring crops, in patches.
        crop_mode (`str`, *optional*, defaults to `"overlap-and-resize-c2"`):
            `"resize"` for a single resized view. `"overlap-and-resize-c2"` or
            `"overlap-and-resize"` for a resized view followed by overlapping crops.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
            The `[height, width]` pooling size of the vision adapter.
    """

    model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]

    def __init__(
        self,
        size: dict[str, int] | None = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool = True,
        max_crops: int = 8,
        overlap_margins: list[int] | None = None,
        crop_mode: str = "overlap-and-resize-c2",
        patch_size: int = 14,
        pooling_size: list[int] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 378, "width": 378}
        self.size = get_size_dict(size, default_to_square=True)

        self.resample = resample
        # Copy the library-level default lists so mutating an instance cannot alter them globally.
        self.image_mean = image_mean if image_mean is not None else list(IMAGENET_STANDARD_MEAN)
        self.image_std = image_std if image_std is not None else list(IMAGENET_STANDARD_STD)
        self.do_convert_rgb = do_convert_rgb

        self.max_crops = max_crops
        # `None` sentinels instead of list defaults: a list default would be shared by every instance.
        self.overlap_margins = list(overlap_margins) if overlap_margins is not None else [4, 4]
        self.crop_mode = crop_mode
        self.patch_size = patch_size
        self.pooling_size = list(pooling_size) if pooling_size is not None else [2, 2]

    def preprocess(
        self,
        images: ImageInput,
        size: dict[str, int] | None = None,
        resample: PILImageResampling | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool | None = None,
        max_crops: int | None = None,
        overlap_margins: list[int] | None = None,
        crop_mode: str | None = None,
        patch_size: int | None = None,
        pooling_size: list[int] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Args:
            images (`ImageInput`):
                Image(s) to preprocess. May be `None`, in which case no image features are returned.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing; must contain `height` and `width`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use when resizing the image.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            max_crops (`int`, *optional*, defaults to `self.max_crops`):
                Maximum number of crops to use per image.
            overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
                Overlap margins to use.
            crop_mode (`str`, *optional*, defaults to `self.crop_mode`):
                Cropping strategy (see the class docstring).
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
                The `[height, width]` pooling size of the vision adapter.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. If unset, the values stay `np.ndarray`.

        Returns:
            A `BatchFeature` containing the following keys (empty when there are no images):
            - `pixel_values`: The patchified crops of all images, concatenated.
            - `image_token_pooling`: The indices of the patches in `pixel_values` to pool for each image token.
            - `image_grids`: One `[global_h, global_w, crops_h, crops_w]` row per image.
            - `image_num_crops`: The number of crops (including the global view) for each image.

        Raises:
            ValueError: If `size` lacks `height`/`width` keys or an image has an unsupported type.
        """
        if size is None:
            size = {**self.size}
        elif "height" not in size or "width" not in size:
            raise ValueError("size must contain 'height' and 'width' keys.")
        base_image_input_size = [size["height"], size["width"]]

        # Fall back to instance defaults only for arguments that were not given. An `or`
        # fallback would ignore falsy-but-valid values such as `do_convert_rgb=False` or
        # `PILImageResampling.NEAREST` (value 0), and it raises on multi-element NumPy arrays.
        resample = self.resample if resample is None else resample
        image_mean = self.image_mean if image_mean is None else image_mean
        image_std = self.image_std if image_std is None else image_std
        do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb

        max_crops = self.max_crops if max_crops is None else max_crops
        overlap_margins = self.overlap_margins if overlap_margins is None else overlap_margins
        crop_mode = self.crop_mode if crop_mode is None else crop_mode
        patch_size = self.patch_size if patch_size is None else patch_size
        pooling_size = self.pooling_size if pooling_size is None else pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        if images is not None:
            images = self.fetch_images(images)
            images = make_flat_list_of_images(images)
            if not valid_images(images):
                raise ValueError(
                    "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                    "torch.Tensor, tf.Tensor or jax.ndarray."
                )
            if do_convert_rgb:
                images = [convert_to_rgb(image) for image in images]
            # All transformations expect numpy arrays.
            images = [to_numpy_array(image) for image in images]

        data = {}
        # Skip both `None` and an empty list; `np.concatenate` cannot handle zero images.
        if images:
            batch_grids = []
            batch_crops = []
            batch_pooled_patches_idx = []
            batch_num_crops = []

            for image in images:
                image_grid, crops, pooled_idx = image_to_patches_and_grids(
                    image,
                    max_crops,
                    overlap_margins,
                    base_image_input_size,
                    resample,
                    image_mean,
                    image_std,
                    patch_size,
                    image_pooling_w,
                    image_pooling_h,
                    crop_mode,
                )
                batch_grids.append(image_grid)
                batch_crops.append(crops)
                batch_pooled_patches_idx.append(pooled_idx)
                batch_num_crops.append(crops.shape[0])

            data.update(
                pixel_values=np.concatenate(batch_crops, 0),
                image_token_pooling=np.concatenate(batch_pooled_patches_idx, 0),
                image_grids=np.concatenate(batch_grids, 0),
                image_num_crops=np.array(batch_num_crops),
            )

        return BatchFeature(data, tensor_type=return_tensors)
|
|
|
|
|
|
# Register the processor with transformers' auto classes (by default `AutoImageProcessor`),
# so saved configs can reference it for loading.
MolmoAct2ImageProcessor.register_for_auto_class()
|