diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py index 6f6c1c4be..68a17466b 100644 --- a/src/lerobot/policies/act/configuration_act.py +++ b/src/lerobot/policies/act/configuration_act.py @@ -98,6 +98,7 @@ class ACTConfig(PreTrainedConfig): normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, + "AUDIO": NormalizationMode.MIN_MAX, "STATE": NormalizationMode.MEAN_STD, "ACTION": NormalizationMode.MEAN_STD, } @@ -108,6 +109,8 @@ class ACTConfig(PreTrainedConfig): vision_backbone: str = "resnet18" pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" replace_final_stride_with_dilation: int = False + # Audio backbone. + audio_backbone: str = vision_backbone # Transformer layers. pre_norm: bool = False dim_model: int = 512 @@ -170,8 +173,10 @@ class ACTConfig(PreTrainedConfig): return None def validate_features(self) -> None: - if not self.image_features and not self.env_state_feature: - raise ValueError("You must provide at least one image or the environment state among the inputs.") + if not (self.image_features or self.audio_features) and not self.env_state_feature: + raise ValueError( + "You must provide at least one image/audio or the environment state among the inputs." + ) @property def observation_delta_indices(self) -> None: diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index a5c48eb3d..bc16c6a6b 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE class ACTPolicy(PreTrainedPolicy): @@ -106,6 +106,8 @@ class ACTPolicy(PreTrainedPolicy): """ self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed + # If we are doing temporal ensembling, do online updates where we keep track of the number of actions + # we are ensembling over. if self.config.temporal_ensemble_coeff is not None: actions = self.predict_action_chunk(batch) action = self.temporal_ensembler.update(actions) @@ -331,12 +333,26 @@ class ACT(nn.Module): # Note: The forward method of this returns a dict: {"feature_map": output}. self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + # Backbone for audio feature extraction. + if self.config.audio_features: + audio_backbone_model = getattr(torchvision.models, config.vision_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + weights=config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final + # feature map). + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.audio_backbone = IntermediateLayerGetter( + audio_backbone_model, return_layers={"layer4": "feature_map"} + ) + # Transformer (acts as VAE decoder when training with the variational objective). self.encoder = ACTEncoder(config) self.decoder = ACTDecoder(config) # Transformer encoder input projections. The tokens will be structured like - # [latent, (robot_state), (env_state), (image_feature_map_pixels)]. + # [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)]. if self.config.robot_state_feature: self.encoder_robot_state_input_proj = nn.Linear( self.config.robot_state_feature.shape[0], config.dim_model @@ -350,6 +366,10 @@ class ACT(nn.Module): self.encoder_img_feat_input_proj = nn.Conv2d( backbone_model.fc.in_features, config.dim_model, kernel_size=1 ) + if self.config.audio_features: + self.encoder_audio_feat_input_proj = nn.Conv2d( + audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1 + ) # Transformer encoder positional embeddings. n_1d_tokens = 1 # for the latent if self.config.robot_state_feature: @@ -359,6 +379,8 @@ class ACT(nn.Module): self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) if self.config.image_features: self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) + if self.config.audio_features: + self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). @@ -483,6 +505,21 @@ class ACT(nn.Module): encoder_in_tokens.extend(list(cam_features)) encoder_in_pos_embed.extend(list(cam_pos_embed)) + if self.config.audio_features: + for audio in batch[OBS_AUDIO]: + audio_features = self.audio_backbone(audio)["feature_map"] + audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to( + dtype=audio_features.dtype + ) + audio_features = self.encoder_audio_feat_input_proj(audio_features) + + # Rearrange features to (sequence, batch, dim). + audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c") + audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c") + + encoder_in_tokens.extend(list(audio_features)) + encoder_in_pos_embed.extend(list(audio_pos_embed)) + # Stack all tokens along the sequence dimension. encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0) diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index 727b18cef..4db7759b0 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -20,6 +20,7 @@ import torch from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, + AudioProcessorStep, DeviceProcessorStep, NormalizerProcessorStep, PolicyAction, @@ -63,6 +64,7 @@ def make_act_pre_post_processors( stats=dataset_stats, device=config.device, ), + AudioProcessorStep(), ] output_steps = [ UnnormalizerProcessorStep( diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 164f7da03..5e20df904 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .audio_processor import AudioProcessorStep from .batch_processor import AddBatchDimensionProcessorStep from .converters import ( batch_to_transition, @@ -80,6 +81,7 @@ __all__ = [ "ActionProcessorStep", "AddTeleopActionAsComplimentaryDataStep", "AddTeleopEventsAsInfoStep", + "AudioProcessorStep", "ComplementaryDataProcessorStep", "batch_to_transition", "create_transition", diff --git a/src/lerobot/processor/audio_processor.py b/src/lerobot/processor/audio_processor.py new file mode 100644 index 000000000..28f8291c9 --- /dev/null +++ b/src/lerobot/processor/audio_processor.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +from torch import Tensor +from torchaudio.functional import amplitude_to_DB +from torchaudio.transforms import MelSpectrogram, Resample +from torchvision.transforms import Compose, Lambda, Resize + +from lerobot.utils.constants import OBS_AUDIO + +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry + + +@dataclass +@ProcessorStepRegistry.register(name="audio_processor") +class AudioProcessorStep(ObservationProcessorStep): + """ + Processes audio waveform data into a mel-spectrogram image representation. + + **Audio Processing:** + - Averages waveform data over all channels. + - Resamples the waveform to 16kHz. + - Converts the waveform to a mel-spectrogram. + - Converts the mel-spectrogram to decibels. + - Resizes the mel-spectrogram to 224×224. + - Converts the mel-spectrogram to a channel-first, normalized tensor. + """ + + # TODO(CarolinePascal) : add variable parametrization + mel_spectrogram_transform = Compose( + [ + Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch) + Resample( + orig_freq=48000, new_freq=16000 + ), # Subsampling (less samples, reduced temporal resolution, lower frequency range) + MelSpectrogram( + sample_rate=16000, # Subsampling (less samples, reduced temporal resolution, lower frequency range) + n_fft=1024, # FFT window size (the bigger the window, the more frequency information, the lower the temporal resolution) + hop_length=36, # Number of samples between frames (the smaller the hop, the higher the temporal resolution) - Value picked to match ResNet18 input and a 0.5s input + n_mels=224, # Number of Mel bands (the more bands, the more rows in the spectrogram, the higher the frequency resolution) + power=2, # Power spectrum + ), + Lambda( + lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0) + ), # Convert to decibels + Resize((224, 224)), # Resize spectrogram to 224×224 + Lambda( + lambda x: x.unsqueeze(1).expand(-1, 3, -1, -1) + ), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width]. + ] + ) + + def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: + """ + Processes audio data contained in the provided observation. + """ + processed_obs = observation.copy() + + # Process single audio observation + if OBS_AUDIO in processed_obs: + audio_data = processed_obs[OBS_AUDIO] + if isinstance(audio_data, Tensor) and audio_data.dim() == 3: # Batch, Channels, Samples + processed_obs[OBS_AUDIO] = self.mel_spectrogram_transform(audio_data) + + # Process multiple audio observations + for key, value in processed_obs.items(): + if ( + key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 3 + ): # Batch, Channels, Samples + processed_obs[key] = self.mel_spectrogram_transform(value) + + return processed_obs + + def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: + return self._process_observation(observation) diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index e1a90421f..0e9c23731 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -25,7 +25,7 @@ from dataclasses import dataclass, field from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from .core import EnvTransition, PolicyAction from .pipeline import ( @@ -88,6 +88,8 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep): - State vectors (1D tensors). - Single images (3D tensors). - Dictionaries of multiple images (3D tensors). + - Single audio waveforms (2D tensors). + - Dictionaries of multiple audio waveforms (2D tensors). """ def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: @@ -117,6 +119,18 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep): for key, value in observation.items(): if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: observation[key] = value.unsqueeze(0) + + # Process single audio observation - add batch dim if 2D + if OBS_AUDIO in observation: + audio_value = observation[OBS_AUDIO] + if isinstance(audio_value, Tensor) and audio_value.dim() == 2: + observation[OBS_AUDIO] = audio_value.unsqueeze(0) + + # Process multiple audio observations - add batch dim if 2D + for key, value in observation.items(): + if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2: + observation[key] = value.unsqueeze(0) + return observation def transform_features(