mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
feat(parametrized audio processor): adding parameters for AudioProcessorStep definition
This commit is contained in:
@@ -17,6 +17,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -64,7 +65,15 @@ def make_act_pre_post_processors(
|
||||
stats=dataset_stats,
|
||||
device=config.device,
|
||||
),
|
||||
AudioProcessorStep(),
|
||||
AudioProcessorStep(
|
||||
output_height=224,
|
||||
output_width=224,
|
||||
output_channels=3,
|
||||
input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
input_sample_rate=48000,
|
||||
intermediate_sample_rate=16000,
|
||||
n_fft=1024,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
|
||||
@@ -13,13 +13,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
from torchaudio.functional import amplitude_to_DB
|
||||
from torchaudio.transforms import MelSpectrogram, Resample
|
||||
from torchvision.transforms import Compose, Lambda, Resize
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.utils.constants import OBS_AUDIO
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@@ -38,31 +39,71 @@ class AudioProcessorStep(ObservationProcessorStep):
|
||||
- Converts the mel-spectrogram to decibels.
|
||||
- Resizes the mel-spectrogram to output_height×output_width (224×224 by default).
|
||||
- Converts the mel-spectrogram to a channel-first, normalized tensor.
|
||||
|
||||
Attributes:
|
||||
output_height: Height of the output mel-spectrogram image in pixels.
|
||||
output_width: Width of the output mel-spectrogram image in pixels.
|
||||
output_channels: Number of channels in the output image (3 for RGB-like format).
|
||||
input_audio_chunk_duration: Duration of the input audio chunk in seconds.
|
||||
input_sample_rate: Original sample rate of the input audio in Hz.
|
||||
|
||||
intermediate_sample_rate: Reduced intermediate sample rate in Hz.
|
||||
Downsampling lowers the highest representable frequency (Nyquist limit); for a fixed n_fft it also lengthens the analysis window in time, which reduces the temporal resolution.
|
||||
n_fft: Size of the FFT window for spectrogram computation.
|
||||
Increasing the window size increases the frequency resolution but decreases the temporal resolution.
|
||||
|
||||
hop_length: Number of samples between successive frames, computed automatically to match the output_width.
|
||||
Decreasing the hop length increases the temporal resolution (more frames per chunk); the frequency resolution is set by n_fft, not by the hop length.
|
||||
n_mels: Number of mel filter banks, computed automatically to match the output_height.
|
||||
Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
|
||||
mel_spectrogram_transform: The complete audio processing pipeline.
|
||||
"""
|
||||
|
||||
# TODO(CarolinePascal) : add variable parametrization
|
||||
mel_spectrogram_transform = Compose(
|
||||
[
|
||||
Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch)
|
||||
Resample(
|
||||
orig_freq=48000, new_freq=16000
|
||||
), # Subsampling (less samples, reduced temporal resolution, lower frequency range)
|
||||
MelSpectrogram(
|
||||
sample_rate=16000, # Subsampling (less samples, reduced temporal resolution, lower frequency range)
|
||||
n_fft=1024, # FFT window size (the bigger the window, the more frequency information, the lower the temporal resolution)
|
||||
hop_length=36, # Number of samples between frames (the smaller the hop, the higher the temporal resolution) - Value picked to match ResNet18 input and a 0.5s input
|
||||
n_mels=224, # Number of Mel bands (the more bands, the more rows in the spectrogram, the higher the frequency resolution)
|
||||
power=2, # Power spectrum
|
||||
),
|
||||
Lambda(
|
||||
lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
|
||||
), # Convert to decibels
|
||||
Resize((224, 224)), # Resize spectrogram to 224×224
|
||||
Lambda(
|
||||
lambda x: x.unsqueeze(1).expand(-1, 3, -1, -1)
|
||||
), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
|
||||
]
|
||||
)
|
||||
output_height: int = 224
|
||||
output_width: int = 224
|
||||
output_channels: int = 3
|
||||
input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION
|
||||
|
||||
input_sample_rate: int = 48000
|
||||
intermediate_sample_rate: int = 16000
|
||||
|
||||
n_fft: int = 1024
|
||||
|
||||
# Parameters computed from other parameters at initialization
|
||||
hop_length: int = field(init=False)
|
||||
n_mels: int = field(init=False)
|
||||
mel_spectrogram_transform: Compose = field(init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
    """Derive the spectrogram parameters and build the audio-to-image transform.

    Sets ``hop_length`` so that sliding an ``n_fft``-long window over one resampled
    audio chunk yields about ``output_width`` frames, sets ``n_mels`` to
    ``output_height``, and assembles ``mel_spectrogram_transform``.
    """
    # Number of samples in one audio chunk once resampled to the intermediate rate.
    num_samples = self.intermediate_sample_rate * self.input_audio_chunk_duration

    # A window of n_fft samples sliding over num_samples samples produces
    # 1 + (num_samples - n_fft) // hop_length frames, so this hop gives ~output_width frames.
    # The parentheses are required: without them `//` binds tighter than `-`, and the hop
    # ends up almost as long as the whole chunk (only a couple of frames).
    # Both clamps keep the hop valid when the chunk is shorter than n_fft or output_width is 1.
    # NOTE(review): MelSpectrogram pads the signal by default (center=True), so the actual
    # frame count may be somewhat above output_width; the Resize step below fixes the final
    # size either way - confirm whether center=False was intended.
    self.hop_length = max(1, int((num_samples - self.n_fft) // max(self.output_width - 1, 1)))

    # One mel band per output row.
    self.n_mels = self.output_height

    self.mel_spectrogram_transform = Compose(
        [
            Lambda(lambda x: x.mean(dim=1)),  # Average over all channels (second dimension after batch)
            Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
            MelSpectrogram(
                sample_rate=self.intermediate_sample_rate,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                n_mels=self.n_mels,
                power=2,  # Power spectrum
            ),
            Lambda(
                lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
            ),  # Convert to decibels
            Resize(
                (self.output_height, self.output_width)
            ),  # Resize spectrogram to output_height×output_width
            Lambda(
                lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
            ),  # Duplicate across output_channels channels to mimic images: [batch, channels, height, width].
        ]
    )
|
||||
|
||||
def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user