mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
feat(audio ACT): remove audio normalization and pretrained audio backbone weights because they do not really make sense
This commit is contained in:
@@ -98,7 +98,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"AUDIO": NormalizationMode.MIN_MAX,
|
||||
"AUDIO": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
@@ -111,6 +111,8 @@ class ACTConfig(PreTrainedConfig):
|
||||
replace_final_stride_with_dilation: int = False
|
||||
# Audio backbone.
|
||||
audio_backbone: str = vision_backbone
|
||||
pretrained_backbone_weights_audio: str | None = ""
|
||||
replace_final_stride_with_dilation_audio: int = False
|
||||
# Transformer layers.
|
||||
pre_norm: bool = False
|
||||
dim_model: int = 512
|
||||
|
||||
@@ -335,9 +335,9 @@ class ACT(nn.Module):
|
||||
|
||||
# Backbone for audio feature extraction.
|
||||
if self.config.audio_features:
|
||||
audio_backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
|
||||
weights=config.pretrained_backbone_weights_audio,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
|
||||
Reference in New Issue
Block a user