mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Adding multiprocessing support for audio recording
This commit is contained in:
@@ -78,9 +78,9 @@ def decode_audio_torchaudio(
|
||||
# TODO(CarolinePascal) : sort timestamps ?
|
||||
|
||||
reader.add_basic_audio_stream(
|
||||
frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough
|
||||
buffer_chunk_size = -1, #No dropping frames
|
||||
format = "fltp", #Format as float32
|
||||
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
|
||||
buffer_chunk_size=-1, # No dropping frames
|
||||
format="fltp", # Format as float32
|
||||
)
|
||||
|
||||
audio_chunks = []
|
||||
@@ -93,7 +93,9 @@ def decode_audio_torchaudio(
|
||||
current_audio_chunk = reader.pop_chunks()[0]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}")
|
||||
logging.info(
|
||||
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
||||
)
|
||||
|
||||
audio_chunks.append(current_audio_chunk)
|
||||
|
||||
|
||||
@@ -20,10 +20,11 @@ import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
from multiprocessing import Event as process_Event, JoinableQueue as process_Queue, Process
|
||||
from os import getcwd
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from threading import Event, Thread
|
||||
from queue import Empty, Queue as thread_Queue
|
||||
from threading import Event, Event as thread_Event, Thread
|
||||
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
@@ -130,7 +131,7 @@ class Microphone:
|
||||
self.config = config
|
||||
self.microphone_index = config.microphone_index
|
||||
|
||||
#Store the recording sample rate and channels
|
||||
# Store the recording sample rate and channels
|
||||
self.sample_rate = config.sample_rate
|
||||
self.channels = config.channels
|
||||
|
||||
@@ -138,8 +139,8 @@ class Microphone:
|
||||
self.stream = None
|
||||
|
||||
# Thread-safe concurrent queue to store the recorded/read audio
|
||||
self.record_queue = Queue()
|
||||
self.read_queue = Queue()
|
||||
self.record_queue = None
|
||||
self.read_queue = None
|
||||
|
||||
# Thread to handle data reading and file writing in a separate thread (safely)
|
||||
self.record_thread = None
|
||||
@@ -166,13 +167,15 @@ class Microphone:
|
||||
# Check if provided recording parameters are compatible with the microphone
|
||||
actual_microphone = sd.query_devices(self.microphone_index)
|
||||
|
||||
if self.sample_rate is not None :
|
||||
if self.sample_rate is not None:
|
||||
if self.sample_rate > actual_microphone["default_samplerate"]:
|
||||
raise OSError(
|
||||
f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
|
||||
)
|
||||
elif self.sample_rate < actual_microphone["default_samplerate"]:
|
||||
logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.")
|
||||
logging.warning(
|
||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||
)
|
||||
else:
|
||||
self.sample_rate = int(actual_microphone["default_samplerate"])
|
||||
|
||||
@@ -191,7 +194,7 @@ class Microphone:
|
||||
self.stream = sd.InputStream(
|
||||
device=self.microphone_index,
|
||||
samplerate=self.sample_rate,
|
||||
channels=max(self.channels)+1,
|
||||
channels=max(self.channels) + 1,
|
||||
dtype="float32",
|
||||
callback=self._audio_callback,
|
||||
)
|
||||
@@ -208,13 +211,24 @@ class Microphone:
|
||||
self.record_queue.put(indata[:, self.channels])
|
||||
self.read_queue.put(indata[:, self.channels])
|
||||
|
||||
def _record_loop(self, output_file: Path) -> None:
|
||||
#Can only be run on a single process/thread for file writing safety
|
||||
with sf.SoundFile(output_file, mode='x', samplerate=self.sample_rate,
|
||||
channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
|
||||
while not self.record_stop_event.is_set():
|
||||
file.write(self.record_queue.get())
|
||||
# self.record_queue.task_done()
|
||||
@staticmethod
|
||||
def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
|
||||
# Can only be run on a single process/thread for file writing safety
|
||||
with sf.SoundFile(
|
||||
output_file,
|
||||
mode="x",
|
||||
samplerate=sample_rate,
|
||||
channels=max(channels) + 1,
|
||||
subtype=sf.default_subtype(output_file.suffix[1:]),
|
||||
) as file:
|
||||
while not event.is_set():
|
||||
try:
|
||||
file.write(
|
||||
queue.get(timeout=0.02)
|
||||
) # Timeout set as twice the usual sounddevice buffer size
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
def _read(self) -> np.ndarray:
|
||||
"""
|
||||
@@ -222,17 +236,15 @@ class Microphone:
|
||||
-> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage
|
||||
-> CONS : Reading duration does not scale well with the number of channels and reading duration
|
||||
"""
|
||||
try:
|
||||
audio_readings = self.read_queue.queue
|
||||
except Queue.Empty:
|
||||
audio_readings = np.empty((0, len(self.channels)))
|
||||
else:
|
||||
# TODO(CarolinePascal): Check if this is the fastest way to do it
|
||||
self.read_queue = Queue()
|
||||
with self.read_queue.mutex:
|
||||
self.read_queue.queue.clear()
|
||||
# self.read_queue.all_tasks_done.notify_all()
|
||||
audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels))
|
||||
audio_readings = np.empty((0, len(self.channels)))
|
||||
|
||||
while True:
|
||||
try:
|
||||
audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
self.read_queue = thread_Queue()
|
||||
|
||||
return audio_readings
|
||||
|
||||
@@ -254,30 +266,49 @@ class Microphone:
|
||||
|
||||
return audio_readings
|
||||
|
||||
def start_recording(self, output_file: str | None = None) -> None:
|
||||
def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if self.is_recording:
|
||||
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
||||
|
||||
self.read_queue = Queue()
|
||||
with self.read_queue.mutex:
|
||||
self.read_queue.queue.clear()
|
||||
# self.read_queue.all_tasks_done.notify_all()
|
||||
# Reset queues
|
||||
self.read_queue = thread_Queue()
|
||||
if multiprocessing:
|
||||
self.record_queue = process_Queue()
|
||||
else:
|
||||
self.record_queue = thread_Queue()
|
||||
|
||||
self.record_queue = Queue()
|
||||
with self.record_queue.mutex:
|
||||
self.record_queue.queue.clear()
|
||||
# self.record_queue.all_tasks_done.notify_all()
|
||||
|
||||
# Recording case
|
||||
# Write recordings into a file if output_file is provided
|
||||
if output_file is not None:
|
||||
output_file = Path(output_file)
|
||||
if output_file.exists():
|
||||
output_file.unlink()
|
||||
|
||||
self.record_stop_event = Event()
|
||||
self.record_thread = Thread(target=self._record_loop, args=(output_file,))
|
||||
if multiprocessing:
|
||||
self.record_stop_event = process_Event()
|
||||
self.record_thread = Process(
|
||||
target=Microphone._record_loop,
|
||||
args=(
|
||||
self.record_queue,
|
||||
self.record_stop_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.record_stop_event = thread_Event()
|
||||
self.record_thread = Thread(
|
||||
target=Microphone._record_loop,
|
||||
args=(
|
||||
self.record_queue,
|
||||
self.record_stop_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
self.record_thread.daemon = True
|
||||
self.record_thread.start()
|
||||
|
||||
@@ -290,18 +321,18 @@ class Microphone:
|
||||
if not self.is_recording:
|
||||
raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
||||
|
||||
if self.stream.active:
|
||||
self.stream.stop() # Wait for all buffers to be processed
|
||||
# Remark : stream.abort() flushes the buffers !
|
||||
self.is_recording = False
|
||||
|
||||
if self.record_thread is not None:
|
||||
# self.record_queue.join()
|
||||
self.record_queue.join()
|
||||
self.record_stop_event.set()
|
||||
self.record_thread.join()
|
||||
self.record_thread = None
|
||||
self.record_stop_event = None
|
||||
|
||||
if self.stream.active:
|
||||
self.stream.stop() # Wait for all buffers to be processed
|
||||
# Remark : stream.abort() flushes the buffers !
|
||||
|
||||
self.is_recording = False
|
||||
self.is_writing = False
|
||||
|
||||
self.logs["stop_timestamp"] = capture_timestamp_utc()
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ class KochFollower(Robot):
|
||||
@property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) for mic in self.microphones
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||
for mic in self.microphones
|
||||
}
|
||||
|
||||
@cached_property
|
||||
|
||||
@@ -102,7 +102,8 @@ class LeKiwi(Robot):
|
||||
@property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) for mic in self.microphones
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||
for mic in self.microphones
|
||||
}
|
||||
|
||||
@cached_property
|
||||
|
||||
@@ -78,7 +78,8 @@ class SOFollower(Robot):
|
||||
@property
|
||||
def _microphones_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels) for mic in self.microphones
|
||||
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
|
||||
for mic in self.microphones
|
||||
}
|
||||
|
||||
@cached_property
|
||||
|
||||
Reference in New Issue
Block a user