use safeauto_docstring

This commit is contained in:
Pepijn
2025-09-12 20:19:16 +02:00
parent f840d2e006
commit 7a03223693
6 changed files with 178 additions and 64 deletions

View File

@@ -45,6 +45,25 @@ from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
@@ -390,7 +409,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
return outputs
@auto_docstring
@safe_auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
config_class = GemmaConfig
base_model_prefix = "model"
@@ -421,7 +440,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@auto_docstring
@safe_auto_docstring
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
@@ -448,7 +467,7 @@ class GemmaModel(GemmaPreTrainedModel):
self.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -566,7 +585,7 @@ class GemmaModel(GemmaPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring
@safe_auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -600,7 +619,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
return self.model
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -684,7 +703,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Gemma Model transformer with a sequence classification head on top (linear layer).
@@ -715,7 +734,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -791,7 +810,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
)
@auto_docstring
@safe_auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -816,7 +835,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -39,8 +39,27 @@ from .configuration_paligemma import PaliGemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for Paligemma outputs, with hidden states and attentions.
"""
@@ -62,7 +81,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for PaliGemma causal language model (or autoregressive) outputs.
"""
@@ -105,7 +124,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
return hidden_states
@auto_docstring
@safe_auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
base_model_prefix = ""
@@ -131,7 +150,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
@@ -258,7 +277,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
return image_features
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -390,7 +409,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
@@ -445,7 +464,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
return self.model.multi_modal_projector
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -37,6 +37,25 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -133,7 +152,7 @@ def default_flax_embed_init(tensor):
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
"""
@@ -152,7 +171,7 @@ class SiglipVisionModelOutput(ModelOutput):
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for text model's outputs that also contains a pooling of the last hidden states.
"""
@@ -171,7 +190,7 @@ class SiglipTextModelOutput(ModelOutput):
@dataclass
@auto_docstring
@safe_auto_docstring
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
r"""
@@ -483,7 +502,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer):
return outputs
@auto_docstring
@safe_auto_docstring
class SiglipPreTrainedModel(PreTrainedModel):
config_class = SiglipConfig
base_model_prefix = "siglip"
@@ -644,7 +663,7 @@ class SiglipTextTransformer(nn.Module):
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -696,7 +715,7 @@ class SiglipTextTransformer(nn.Module):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The text model from SigLIP without any head or projection on top.
"""
@@ -717,7 +736,7 @@ class SiglipTextModel(SiglipPreTrainedModel):
self.text_model.embeddings.token_embedding = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -766,7 +785,7 @@ class SiglipVisionTransformer(nn.Module):
self.head = SiglipMultiheadAttentionPoolingHead(config)
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
pixel_values,
@@ -834,7 +853,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
return hidden_state[:, 0]
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The vision model from SigLIP without any head or projection on top.
"""
@@ -855,7 +874,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
return self.vision_model.embeddings.patch_embedding
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
pixel_values,
@@ -892,7 +911,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
)
@auto_docstring
@safe_auto_docstring
class SiglipModel(SiglipPreTrainedModel):
config_class = SiglipConfig
@@ -928,7 +947,7 @@ class SiglipModel(SiglipPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@auto_docstring
@safe_auto_docstring
def get_text_features(
self,
input_ids: torch.Tensor | None = None,
@@ -976,7 +995,7 @@ class SiglipModel(SiglipPreTrainedModel):
return pooled_output
@auto_docstring
@safe_auto_docstring
def get_image_features(
self,
pixel_values: torch.FloatTensor | None = None,
@@ -1028,7 +1047,7 @@ class SiglipModel(SiglipPreTrainedModel):
return pooled_output
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -1131,7 +1150,7 @@ class SiglipModel(SiglipPreTrainedModel):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
the patch tokens) e.g. for ImageNet.
@@ -1161,7 +1180,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
self.post_init()
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
pixel_values: torch.Tensor | None = None,

View File

@@ -45,6 +45,25 @@ from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
@@ -390,7 +409,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
return outputs
@auto_docstring
@safe_auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
config_class = GemmaConfig
base_model_prefix = "model"
@@ -421,7 +440,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@auto_docstring
@safe_auto_docstring
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
@@ -448,7 +467,7 @@ class GemmaModel(GemmaPreTrainedModel):
self.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -566,7 +585,7 @@ class GemmaModel(GemmaPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring
@safe_auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -600,7 +619,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
return self.model
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -684,7 +703,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Gemma Model transformer with a sequence classification head on top (linear layer).
@@ -715,7 +734,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -791,7 +810,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
)
@auto_docstring
@safe_auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -816,7 +835,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -39,8 +39,27 @@ from .configuration_paligemma import PaliGemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for Paligemma outputs, with hidden states and attentions.
"""
@@ -62,7 +81,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for PaliGemma causal language model (or autoregressive) outputs.
"""
@@ -105,7 +124,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
return hidden_states
@auto_docstring
@safe_auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
base_model_prefix = ""
@@ -131,7 +150,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
@@ -258,7 +277,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
return image_features
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -390,7 +409,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
@@ -445,7 +464,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
return self.model.multi_modal_projector
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -37,6 +37,25 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -133,7 +152,7 @@ def default_flax_embed_init(tensor):
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
"""
@@ -152,7 +171,7 @@ class SiglipVisionModelOutput(ModelOutput):
@dataclass
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
Base class for text model's outputs that also contains a pooling of the last hidden states.
"""
@@ -171,7 +190,7 @@ class SiglipTextModelOutput(ModelOutput):
@dataclass
@auto_docstring
@safe_auto_docstring
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
r"""
@@ -483,7 +502,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer):
return outputs
@auto_docstring
@safe_auto_docstring
class SiglipPreTrainedModel(PreTrainedModel):
config_class = SiglipConfig
base_model_prefix = "siglip"
@@ -644,7 +663,7 @@ class SiglipTextTransformer(nn.Module):
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -696,7 +715,7 @@ class SiglipTextTransformer(nn.Module):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The text model from SigLIP without any head or projection on top.
"""
@@ -717,7 +736,7 @@ class SiglipTextModel(SiglipPreTrainedModel):
self.text_model.embeddings.token_embedding = value
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -766,7 +785,7 @@ class SiglipVisionTransformer(nn.Module):
self.head = SiglipMultiheadAttentionPoolingHead(config)
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
pixel_values,
@@ -834,7 +853,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
return hidden_state[:, 0]
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
The vision model from SigLIP without any head or projection on top.
"""
@@ -855,7 +874,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
return self.vision_model.embeddings.patch_embedding
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
pixel_values,
@@ -892,7 +911,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
)
@auto_docstring
@safe_auto_docstring
class SiglipModel(SiglipPreTrainedModel):
config_class = SiglipConfig
@@ -928,7 +947,7 @@ class SiglipModel(SiglipPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@auto_docstring
@safe_auto_docstring
def get_text_features(
self,
input_ids: torch.Tensor | None = None,
@@ -976,7 +995,7 @@ class SiglipModel(SiglipPreTrainedModel):
return pooled_output
@auto_docstring
@safe_auto_docstring
def get_image_features(
self,
pixel_values: torch.FloatTensor | None = None,
@@ -1028,7 +1047,7 @@ class SiglipModel(SiglipPreTrainedModel):
return pooled_output
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -1131,7 +1150,7 @@ class SiglipModel(SiglipPreTrainedModel):
)
@auto_docstring(
@safe_auto_docstring(
custom_intro="""
SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
the patch tokens) e.g. for ImageNet.
@@ -1161,7 +1180,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
self.post_init()
@can_return_tuple
@auto_docstring
@safe_auto_docstring
def forward(
self,
pixel_values: torch.Tensor | None = None,