mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
revert to openpi transformer replace python 3.11
This commit is contained in:
@@ -517,6 +517,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
|
||||
msg = """transformers_replace is not installed correctly.
|
||||
Please install it with `pip install transformers==4.53.2`
|
||||
and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
|
||||
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
|
||||
@@ -45,25 +45,6 @@ from .configuration_gemma import GemmaConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
@@ -374,9 +355,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: (
|
||||
None | tuple[torch.Tensor, torch.Tensor]
|
||||
) = None, # necessary, but kept here for BC
|
||||
position_embeddings: None
|
||||
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
@@ -410,7 +390,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = GemmaConfig
|
||||
base_model_prefix = "model"
|
||||
@@ -441,7 +421,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
@@ -468,7 +448,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -540,7 +520,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841
|
||||
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
@@ -586,7 +566,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
@@ -620,7 +600,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -704,7 +684,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Gemma Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
@@ -735,7 +715,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -811,7 +791,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@@ -836,7 +816,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -39,27 +39,8 @@ from .configuration_paligemma import PaliGemmaConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Paligemma outputs, with hidden states and attentions.
|
||||
"""
|
||||
@@ -81,7 +62,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PaliGemma causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
@@ -124,7 +105,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = PaliGemmaConfig
|
||||
base_model_prefix = ""
|
||||
@@ -150,7 +131,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
@@ -277,7 +258,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@@ -336,7 +317,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id worth PAD if the image token if OOV, to avoid index-errors
|
||||
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
@@ -409,7 +390,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
@@ -450,7 +431,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.model.get_image_features(pixel_values)
|
||||
|
||||
# Make modules available conditional class for BC
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
@@ -464,7 +445,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
import transformers
|
||||
|
||||
|
||||
def check_whether_transformers_replace_is_installed_correctly():
|
||||
return transformers.__version__ == "4.53.2"
|
||||
@@ -37,25 +37,6 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
def _trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
@@ -78,7 +59,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
@@ -152,7 +133,7 @@ def default_flax_embed_init(tensor):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
||||
"""
|
||||
@@ -171,7 +152,7 @@ class SiglipVisionModelOutput(ModelOutput):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
||||
"""
|
||||
@@ -190,7 +171,7 @@ class SiglipTextModelOutput(ModelOutput):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
|
||||
class SiglipOutput(ModelOutput):
|
||||
r"""
|
||||
@@ -502,7 +483,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class SiglipPreTrainedModel(PreTrainedModel):
|
||||
config_class = SiglipConfig
|
||||
base_model_prefix = "siglip"
|
||||
@@ -663,7 +644,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
@@ -715,7 +696,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The text model from SigLIP without any head or projection on top.
|
||||
"""
|
||||
@@ -736,7 +717,7 @@ class SiglipTextModel(SiglipPreTrainedModel):
|
||||
self.text_model.embeddings.token_embedding = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
@@ -785,7 +766,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
@@ -853,7 +834,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
return hidden_state[:, 0]
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The vision model from SigLIP without any head or projection on top.
|
||||
"""
|
||||
@@ -874,7 +855,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
@@ -911,7 +892,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class SiglipModel(SiglipPreTrainedModel):
|
||||
config_class = SiglipConfig
|
||||
|
||||
@@ -947,7 +928,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
@@ -995,7 +976,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
|
||||
return pooled_output
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
@@ -1047,7 +1028,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
return pooled_output
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -1150,7 +1131,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
|
||||
the patch tokens) e.g. for ImageNet.
|
||||
@@ -1180,7 +1161,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
|
||||
@@ -518,6 +518,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
|
||||
msg = """transformers_replace is not installed correctly.
|
||||
Please install it with `pip install transformers==4.53.2`
|
||||
and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
|
||||
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
|
||||
@@ -45,25 +45,6 @@ from .configuration_gemma import GemmaConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
@@ -374,9 +355,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: (
|
||||
None | tuple[torch.Tensor, torch.Tensor]
|
||||
) = None, # necessary, but kept here for BC
|
||||
position_embeddings: None
|
||||
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
@@ -410,7 +390,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = GemmaConfig
|
||||
base_model_prefix = "model"
|
||||
@@ -441,7 +421,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
@@ -468,7 +448,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -540,7 +520,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841
|
||||
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
@@ -586,7 +566,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
@@ -620,7 +600,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -704,7 +684,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Gemma Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
@@ -735,7 +715,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -811,7 +791,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@@ -836,7 +816,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -39,27 +39,8 @@ from .configuration_paligemma import PaliGemmaConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Paligemma outputs, with hidden states and attentions.
|
||||
"""
|
||||
@@ -81,7 +62,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PaliGemma causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
@@ -124,7 +105,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = PaliGemmaConfig
|
||||
base_model_prefix = ""
|
||||
@@ -150,7 +131,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
@@ -277,7 +258,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@@ -336,7 +317,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id worth PAD if the image token if OOV, to avoid index-errors
|
||||
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
@@ -409,7 +390,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
@@ -450,7 +431,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.model.get_image_features(pixel_values)
|
||||
|
||||
# Make modules available conditional class for BC
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
@@ -464,7 +445,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
import transformers
|
||||
|
||||
|
||||
def check_whether_transformers_replace_is_installed_correctly():
|
||||
return transformers.__version__ == "4.53.2"
|
||||
@@ -37,25 +37,6 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
|
||||
def safe_auto_docstring(func=None, **kwargs):
|
||||
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
|
||||
|
||||
def decorator(f):
|
||||
try:
|
||||
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
|
||||
except (AttributeError, TypeError):
|
||||
# If auto_docstring fails due to UnionType, just return the function unchanged
|
||||
return f
|
||||
|
||||
if func is None:
|
||||
# Called with arguments, return the decorator
|
||||
return decorator
|
||||
else:
|
||||
# Called without arguments, apply directly
|
||||
return decorator(func)
|
||||
|
||||
|
||||
def _trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
@@ -78,7 +59,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
@@ -152,7 +133,7 @@ def default_flax_embed_init(tensor):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
||||
"""
|
||||
@@ -171,7 +152,7 @@ class SiglipVisionModelOutput(ModelOutput):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
||||
"""
|
||||
@@ -190,7 +171,7 @@ class SiglipTextModelOutput(ModelOutput):
|
||||
|
||||
|
||||
@dataclass
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
|
||||
class SiglipOutput(ModelOutput):
|
||||
r"""
|
||||
@@ -502,7 +483,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class SiglipPreTrainedModel(PreTrainedModel):
|
||||
config_class = SiglipConfig
|
||||
base_model_prefix = "siglip"
|
||||
@@ -663,7 +644,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
@@ -715,7 +696,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The text model from SigLIP without any head or projection on top.
|
||||
"""
|
||||
@@ -736,7 +717,7 @@ class SiglipTextModel(SiglipPreTrainedModel):
|
||||
self.text_model.embeddings.token_embedding = value
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
@@ -785,7 +766,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
@@ -853,7 +834,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
return hidden_state[:, 0]
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The vision model from SigLIP without any head or projection on top.
|
||||
"""
|
||||
@@ -874,7 +855,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
@@ -911,7 +892,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
class SiglipModel(SiglipPreTrainedModel):
|
||||
config_class = SiglipConfig
|
||||
|
||||
@@ -947,7 +928,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
@@ -995,7 +976,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
|
||||
return pooled_output
|
||||
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
@@ -1047,7 +1028,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
return pooled_output
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -1150,7 +1131,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@safe_auto_docstring(
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
|
||||
the patch tokens) e.g. for ImageNet.
|
||||
@@ -1180,7 +1161,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@safe_auto_docstring
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
|
||||
Reference in New Issue
Block a user