From 2a4c223ec79aed452e0bfb1e830bd4e888557819 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 30 Oct 2025 18:22:33 +0100 Subject: [PATCH] feat(parametrized audio processor): adding parameters for AudioProcessorStep definition --- src/lerobot/policies/act/processor_act.py | 11 ++- src/lerobot/processor/audio_processor.py | 89 +++++++++++++++++------ 2 files changed, 75 insertions(+), 25 deletions(-) diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index 4db7759b0..ba5e1dc78 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -17,6 +17,7 @@ from typing import Any import torch +from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -64,7 +65,15 @@ def make_act_pre_post_processors( stats=dataset_stats, device=config.device, ), - AudioProcessorStep(), + AudioProcessorStep( + output_height=224, + output_width=224, + output_channels=3, + input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION, + input_sample_rate=48000, + intermediate_sample_rate=16000, + n_fft=1024, + ), ] output_steps = [ UnnormalizerProcessorStep( diff --git a/src/lerobot/processor/audio_processor.py b/src/lerobot/processor/audio_processor.py index 28f8291c9..8102b0ab1 100644 --- a/src/lerobot/processor/audio_processor.py +++ b/src/lerobot/processor/audio_processor.py @@ -13,13 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch import Tensor from torchaudio.functional import amplitude_to_DB from torchaudio.transforms import MelSpectrogram, Resample from torchvision.transforms import Compose, Lambda, Resize +from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION from lerobot.utils.constants import OBS_AUDIO from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -38,31 +39,71 @@ class AudioProcessorStep(ObservationProcessorStep): - Converts the mel-spectrogram to decibels. - Resizes the mel-spectrogram to 224×224. - Converts the mel-spectrogram to a channel-first, normalized tensor. + + Attributes: + output_height: Height of the output mel-spectrogram image in pixels. + output_width: Width of the output mel-spectrogram image in pixels. + output_channels: Number of channels in the output image (3 for RGB-like format). + input_audio_chunk_duration: Duration of the input audio chunk in seconds. + input_sample_rate: Original sample rate of the input audio in Hz. + + intermediate_sample_rate: Reduced intermediate sample rate in Hz. + Downsampling improves the temporal resolution but reduces the frequency range. + n_fft: Size of the FFT window for spectrogram computation. + Increasing the window size increases the frequency resolution but decreases the temporal resolution. + + hop_length: Number of samples between successive frames, computed automatically to match the output_width. + Decreasing the hop length increases the temporal resolution but decreases the frequency resolution. + n_mels: Number of mel filter banks, computed automatically to match the output_height. + Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution. + mel_spectrogram_transform: The complete audio processing pipeline. """ - # TODO(CarolinePascal) : add variable parametrization - mel_spectrogram_transform = Compose( - [ - Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch) - Resample( - orig_freq=48000, new_freq=16000 - ), # Subsampling (less samples, reduced temporal resolution, lower frequency range) - MelSpectrogram( - sample_rate=16000, # Subsampling (less samples, reduced temporal resolution, lower frequency range) - n_fft=1024, # FFT window size (the bigger the window, the more frequency information, the lower the temporal resolution) - hop_length=36, # Number of samples between frames (the smaller the hop, the higher the temporal resolution) - Value picked to match ResNet18 input and a 0.5s input - n_mels=224, # Number of Mel bands (the more bands, the more rows in the spectrogram, the higher the frequency resolution) - power=2, # Power spectrum - ), - Lambda( - lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0) - ), # Convert to decibels - Resize((224, 224)), # Resize spectrogram to 224×224 - Lambda( - lambda x: x.unsqueeze(1).expand(-1, 3, -1, -1) - ), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width]. - ] - ) + output_height: int = 224 + output_width: int = 224 + output_channels: int = 3 + input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION + + input_sample_rate: int = 48000 + intermediate_sample_rate: int = 16000 + + n_fft: int = 1024 + + # Parameters computed from other parameters at initialization + hop_length: int = field(init=False) + n_mels: int = field(init=False) + mel_spectrogram_transform: Compose = field(init=False, repr=False) + + def __post_init__(self): + self.hop_length = int( + self.intermediate_sample_rate * self.input_audio_chunk_duration + - self.n_fft // self.output_width + - 1 + ) + self.n_mels = self.output_height + + self.mel_spectrogram_transform = Compose( + [ + Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch) + Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate), + MelSpectrogram( + sample_rate=self.intermediate_sample_rate, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + power=2, # Power spectrum + ), + Lambda( + lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0) + ), # Convert to decibels + Resize( + (self.output_height, self.output_width) + ), # Resize spectrogram to output_height×output_width + Lambda( + lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1) + ), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width]. + ] + ) def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: """