From 52c424c5ebd7ff3c91c951d1e54cd16ceb290770 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 9 Apr 2025 14:59:29 +0200 Subject: [PATCH] Adding multiprocessing support for audio recording --- src/lerobot/datasets/audio_utils.py | 10 +- src/lerobot/microphones/microphone.py | 123 +++++++++++------- .../robots/koch_follower/koch_follower.py | 3 +- src/lerobot/robots/lekiwi/lekiwi.py | 3 +- src/lerobot/robots/so_follower/so_follower.py | 3 +- 5 files changed, 89 insertions(+), 53 deletions(-) diff --git a/src/lerobot/datasets/audio_utils.py b/src/lerobot/datasets/audio_utils.py index d073109a0..da4d97f82 100644 --- a/src/lerobot/datasets/audio_utils.py +++ b/src/lerobot/datasets/audio_utils.py @@ -78,9 +78,9 @@ def decode_audio_torchaudio( # TODO(CarolinePascal) : sort timestamps ? reader.add_basic_audio_stream( - frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough - buffer_chunk_size = -1, #No dropping frames - format = "fltp", #Format as float32 + frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + format="fltp", # Format as float32 ) audio_chunks = [] @@ -93,7 +93,9 @@ def decode_audio_torchaudio( current_audio_chunk = reader.pop_chunks()[0] if log_loaded_timestamps: - logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}") + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" + ) audio_chunks.append(current_audio_chunk) diff --git a/src/lerobot/microphones/microphone.py b/src/lerobot/microphones/microphone.py index 092ab95c9..41c63bb1a 100644 --- a/src/lerobot/microphones/microphone.py +++ b/src/lerobot/microphones/microphone.py @@ -20,10 +20,11 @@ import argparse import logging import shutil import time +from multiprocessing import Event as process_Event, JoinableQueue as process_Queue, Process from os import getcwd from pathlib import Path -from queue import Queue -from threading import Event, Thread +from queue import Empty, Queue as thread_Queue +from threading import Event, Event as thread_Event, Thread import numpy as np import sounddevice as sd @@ -130,7 +131,7 @@ class Microphone: self.config = config self.microphone_index = config.microphone_index - #Store the recording sample rate and channels + # Store the recording sample rate and channels self.sample_rate = config.sample_rate self.channels = config.channels @@ -138,8 +139,8 @@ class Microphone: self.stream = None # Thread-safe concurrent queue to store the recorded/read audio - self.record_queue = Queue() - self.read_queue = Queue() + self.record_queue = None + self.read_queue = None # Thread to handle data reading and file writing in a separate thread (safely) self.record_thread = None @@ -166,13 +167,15 @@ class Microphone: # Check if provided recording parameters are compatible with the microphone actual_microphone = sd.query_devices(self.microphone_index) - if self.sample_rate is not None : + if self.sample_rate is not None: if self.sample_rate > actual_microphone["default_samplerate"]: raise OSError( f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}." ) elif self.sample_rate < actual_microphone["default_samplerate"]: - logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.") + logging.warning( + "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." + ) else: self.sample_rate = int(actual_microphone["default_samplerate"]) @@ -191,7 +194,7 @@ class Microphone: self.stream = sd.InputStream( device=self.microphone_index, samplerate=self.sample_rate, - channels=max(self.channels)+1, + channels=max(self.channels) + 1, dtype="float32", callback=self._audio_callback, ) @@ -208,13 +211,24 @@ class Microphone: self.record_queue.put(indata[:, self.channels]) self.read_queue.put(indata[:, self.channels]) - def _record_loop(self, output_file: Path) -> None: - #Can only be run on a single process/thread for file writing safety - with sf.SoundFile(output_file, mode='x', samplerate=self.sample_rate, - channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: - while not self.record_stop_event.is_set(): - file.write(self.record_queue.get()) - # self.record_queue.task_done() + @staticmethod + def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: + # Can only be run on a single process/thread for file writing safety + with sf.SoundFile( + output_file, + mode="x", + samplerate=sample_rate, + channels=max(channels) + 1, + subtype=sf.default_subtype(output_file.suffix[1:]), + ) as file: + while not event.is_set(): + try: + file.write( + queue.get(timeout=0.02) + ) # Timeout set as twice the usual sounddevice buffer size + queue.task_done() + except Empty: + continue def _read(self) -> np.ndarray: """ @@ -222,17 +236,15 @@ class Microphone: -> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage -> CONS : Reading duration does not scale well with the number of channels and reading duration """ - try: - audio_readings = self.read_queue.queue - except Queue.Empty: - audio_readings = np.empty((0, len(self.channels))) - else: - # TODO(CarolinePascal): Check if this is the fastest way to do it - self.read_queue = Queue() - with self.read_queue.mutex: - self.read_queue.queue.clear() - # self.read_queue.all_tasks_done.notify_all() - audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels)) + audio_readings = np.empty((0, len(self.channels))) + + while True: + try: + audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0) + except Empty: + break + + self.read_queue = thread_Queue() return audio_readings @@ -254,30 +266,49 @@ class Microphone: return audio_readings - def start_recording(self, output_file: str | None = None) -> None: + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None: if not self.is_connected: raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if self.is_recording: raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.") - self.read_queue = Queue() - with self.read_queue.mutex: - self.read_queue.queue.clear() - # self.read_queue.all_tasks_done.notify_all() + # Reset queues + self.read_queue = thread_Queue() + if multiprocessing: + self.record_queue = process_Queue() + else: + self.record_queue = thread_Queue() - self.record_queue = Queue() - with self.record_queue.mutex: - self.record_queue.queue.clear() - # self.record_queue.all_tasks_done.notify_all() - - # Recording case + # Write recordings into a file if output_file is provided if output_file is not None: output_file = Path(output_file) if output_file.exists(): output_file.unlink() - self.record_stop_event = Event() - self.record_thread = Thread(target=self._record_loop, args=(output_file,)) + if multiprocessing: + self.record_stop_event = process_Event() + self.record_thread = Process( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) + else: + self.record_stop_event = thread_Event() + self.record_thread = Thread( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) self.record_thread.daemon = True self.record_thread.start() @@ -290,18 +321,18 @@ class Microphone: if not self.is_recording: raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") + if self.stream.active: + self.stream.stop() # Wait for all buffers to be processed + # Remark : stream.abort() flushes the buffers ! + self.is_recording = False + if self.record_thread is not None: - # self.record_queue.join() + self.record_queue.join() self.record_stop_event.set() self.record_thread.join() self.record_thread = None self.record_stop_event = None - - if self.stream.active: - self.stream.stop() # Wait for all buffers to be processed - # Remark : stream.abort() flushes the buffers ! - - self.is_recording = False + self.is_writing = False self.logs["stop_timestamp"] = capture_timestamp_utc() diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 9921d416b..b1eb0091d 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -77,7 +77,8 @@ class KochFollower(Robot): @property def _microphones_ft(self) -> dict[str, tuple]: return { - mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) for mic in self.microphones + mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) + for mic in self.microphones } @cached_property diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 5bf3ca41b..05f366689 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -102,7 +102,8 @@ class LeKiwi(Robot): @property def _microphones_ft(self) -> dict[str, tuple]: return { - mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) for mic in self.microphones + mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) + for mic in self.microphones } @cached_property diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index 17ed4c2c2..4de4b8a99 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -78,7 +78,8 @@ class SOFollower(Robot): @property def _microphones_ft(self) -> dict[str, tuple]: return { - mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) for mic in self.microphones + mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) + for mic in self.microphones } @cached_property