From 3bbd161cfdd7b5311b5de8aea7923b5f28dbe0a6 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 17 Apr 2025 20:02:45 +0200 Subject: [PATCH] [skip ci] feat(audio recording): adding new asyn start_recording, stop_recording and read functions to avoid for loop delays --- src/lerobot/datasets/lerobot_dataset.py | 16 ++++- .../portaudio/microphone_portaudio.py | 10 +-- src/lerobot/microphones/utils.py | 70 +++++++++++++++++++ 3 files changed, 91 insertions(+), 5 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index ccbee263a..f9ffd6694 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -80,6 +80,7 @@ from lerobot.datasets.video_utils import ( get_video_info, ) from lerobot.microphones import Microphone +from lerobot.microphones.utils import async_microphones_start_recording from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" @@ -1312,7 +1313,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episode_buffer["size"] += 1 - def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None: + def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None: """ Starts recording audio data provided by the microphone and directly writes it in a .wav file. """ @@ -1320,6 +1321,19 @@ class LeRobotDataset(torch.utils.data.Dataset): audio_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key) microphone.start_recording(output_file=audio_file) + def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None: + """ + Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files. + """ + + output_files = [] + for microphone_key in microphones: + output_files.append( + self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key) + ) + + async_microphones_start_recording(microphones, output_files) + def save_episode( self, episode_data: dict | None = None, diff --git a/src/lerobot/microphones/portaudio/microphone_portaudio.py b/src/lerobot/microphones/portaudio/microphone_portaudio.py index 106f9060d..d2f320fc6 100644 --- a/src/lerobot/microphones/portaudio/microphone_portaudio.py +++ b/src/lerobot/microphones/portaudio/microphone_portaudio.py @@ -185,9 +185,10 @@ class PortAudioMicrophone(Microphone): logging.warning(status) # Slicing makes copy unnecessary # Two separate queues are necessary because .get() also pops the data from the queue + # Remark: this also ensures that file-recorded data and chunk-audio data are the same. if self.is_writing: - self.record_queue.put(indata[:, self.channels_index]) - self.read_queue.put(indata[:, self.channels_index]) + self.record_queue.put_nowait(indata[:, self.channels_index]) + self.read_queue.put_nowait(indata[:, self.channels_index]) @staticmethod def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: @@ -206,8 +207,8 @@ class PortAudioMicrophone(Microphone): while not event.is_set(): try: file.write( - queue.get(timeout=0.02) - ) # Timeout set as twice the usual sounddevice buffer size + queue.get(timeout=0.01) + ) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread. queue.task_done() except Empty: continue @@ -257,6 +258,7 @@ class PortAudioMicrophone(Microphone): ) -> None: """ Starts the recording of the microphone. If output_file is provided, the audio will be written to this file. + Remark: multiprocessing is implemented, but does not work well with sounddevice (launching delays, tricky memory sharing, sounddevice streams are not picklable (even with dill #pathos), etc.). """ if not self.is_connected: raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") diff --git a/src/lerobot/microphones/utils.py b/src/lerobot/microphones/utils.py index 4d2b78b81..675c51e1f 100644 --- a/src/lerobot/microphones/utils.py +++ b/src/lerobot/microphones/utils.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from queue import Queue +from threading import Thread + from .configs import MicrophoneConfig from .microphone import Microphone @@ -28,3 +31,70 @@ def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig raise ValueError(f"The microphone type '{cfg.type}' is not valid.") return microphones + + +def async_microphones_start_recording( + microphones: dict[str, Microphone], + output_files: list[str | None] | None, + multiprocessing: bool = False, + overwrite: bool = True, +): + """ + Starts recording on multiple microphones asynchronously to avoid delays + """ + + start_recording_threads = [] + if output_files is None: + output_files = [None] * len(microphones) + + for microphone, output_file in zip(microphones.values(), output_files, strict=False): + start_recording_threads.append( + Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite)) + ) + + for thread in start_recording_threads: + thread.start() + for thread in start_recording_threads: + thread.join() + + +def async_microphones_stop_recording(microphones: dict[str, Microphone]): + """ + Stops recording on multiple microphones asynchronously to avoid delays + """ + + stop_recording_threads = [] + + for microphone in microphones.values(): + stop_recording_threads.append(Thread(target=microphone.stop_recording)) + + for thread in stop_recording_threads: + thread.start() + for thread in stop_recording_threads: + thread.join() + + +def async_microphones_read(microphones: dict[str, Microphone]): + """ + Reads from multiple microphones asynchronously to avoid delays + """ + + read_threads = [] + read_queue = Queue() + + for microphone_key, microphone in microphones.items(): + read_threads.append( + Thread( + target=lambda microphone, output, microphone_key: output.put_nowait( + {microphone_key: microphone.read()} + ), + args=(microphone, read_queue, microphone_key), + ) + ) + + for thread in read_threads: + thread.start() + for thread in read_threads: + thread.join() + + return dict(kv for d in read_queue.queue for kv in d.items())