diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py index 68a17466b..baf5feb4c 100644 --- a/src/lerobot/policies/act/configuration_act.py +++ b/src/lerobot/policies/act/configuration_act.py @@ -98,7 +98,7 @@ class ACTConfig(PreTrainedConfig): normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, - "AUDIO": NormalizationMode.MIN_MAX, + "AUDIO": NormalizationMode.IDENTITY, "STATE": NormalizationMode.MEAN_STD, "ACTION": NormalizationMode.MEAN_STD, } @@ -111,6 +111,8 @@ class ACTConfig(PreTrainedConfig): replace_final_stride_with_dilation: int = False # Audio backbone. audio_backbone: str = vision_backbone + pretrained_backbone_weights_audio: str | None = "" + replace_final_stride_with_dilation_audio: int = False # Transformer layers. pre_norm: bool = False dim_model: int = 512 diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index bc16c6a6b..27a2410a2 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -335,9 +335,9 @@ class ACT(nn.Module): # Backbone for audio feature extraction. if self.config.audio_features: - audio_backbone_model = getattr(torchvision.models, config.vision_backbone)( - replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], - weights=config.pretrained_backbone_weights, + audio_backbone_model = getattr(torchvision.models, config.audio_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio], + weights=config.pretrained_backbone_weights_audio, norm_layer=FrozenBatchNorm2d, ) # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final