feat(audio ACT): removing normalization and pretrained weights because it does not really make sense

This commit is contained in:
CarolinePascal
2025-05-23 16:01:35 +02:00
parent 674f5dfd75
commit 0fbcbcdb2e
2 changed files with 6 additions and 4 deletions

View File

@@ -98,7 +98,7 @@ class ACTConfig(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"AUDIO": NormalizationMode.MIN_MAX,
"AUDIO": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
@@ -111,6 +111,8 @@ class ACTConfig(PreTrainedConfig):
replace_final_stride_with_dilation: int = False
# Audio backbone.
audio_backbone: str = vision_backbone
pretrained_backbone_weights_audio: str | None = ""
replace_final_stride_with_dilation_audio: int = False
# Transformer layers.
pre_norm: bool = False
dim_model: int = 512

View File

@@ -335,9 +335,9 @@ class ACT(nn.Module):
# Backbone for audio feature extraction.
if self.config.audio_features:
audio_backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
weights=config.pretrained_backbone_weights_audio,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final