mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
style(config validation): storing microphone config validation in dedicated methods
This commit is contained in:
@@ -123,6 +123,74 @@ class PortAudioMicrophone(Microphone):
|
||||
|
||||
return found_microphones_info
|
||||
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Validates the microphone index, sample rate and channels settings specified in the constructor's config to the un-connected microphone.
|
||||
|
||||
This method actually checks the specified settings and fills the sample rate and channels settings if not specified before attempting to start a PortAudio stream.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If one of the specified settings is not compatible with the microphone.
|
||||
DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"Cannot configure settings for {self} as it is already connected."
|
||||
)
|
||||
|
||||
self._validate_microphone_index()
|
||||
self._validate_sample_rate()
|
||||
self._validate_channels()
|
||||
|
||||
def _validate_microphone_index(self) -> None:
|
||||
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
|
||||
|
||||
is_index_input = (
|
||||
self.microphone_index >= 0 and sd.query_devices(self.microphone_index)["max_input_channels"] > 0
|
||||
)
|
||||
|
||||
if not is_index_input:
|
||||
found_microphones_info = self.find_microphones()
|
||||
available_microphones = {m["name"]: m["index"] for m in found_microphones_info}
|
||||
raise RuntimeError(
|
||||
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
||||
)
|
||||
|
||||
def _validate_sample_rate(self) -> None:
|
||||
"""Validates the sample rate against the actual microphone's default sample rate."""
|
||||
|
||||
actual_sample_rate = sd.query_devices(self.microphone_index)["default_samplerate"]
|
||||
|
||||
if self.sample_rate is not None:
|
||||
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
|
||||
raise RuntimeError(
|
||||
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
|
||||
)
|
||||
else:
|
||||
if self.sample_rate < actual_sample_rate:
|
||||
logging.warning(
|
||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||
)
|
||||
self.sample_rate = int(self.sample_rate)
|
||||
else:
|
||||
self.sample_rate = int(actual_sample_rate)
|
||||
|
||||
def _validate_channels(self) -> None:
|
||||
"""Validates the channels against the actual microphone's maximum input channels."""
|
||||
|
||||
actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"]
|
||||
|
||||
if self.channels is not None and len(self.channels) > 0:
|
||||
if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels):
|
||||
raise RuntimeError(
|
||||
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}."
|
||||
)
|
||||
else:
|
||||
self.channels = np.arange(1, actual_max_microphone_channels + 1)
|
||||
|
||||
# Get channels index instead of number for slicing
|
||||
self.channels_index = np.array(self.channels) - 1
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
|
||||
@@ -130,42 +198,7 @@ class PortAudioMicrophone(Microphone):
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
|
||||
|
||||
# Check if the provided microphone index does match an input device
|
||||
is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
|
||||
|
||||
if not is_index_input:
|
||||
found_microphones_info = self.find_microphones()
|
||||
available_microphones = {m["name"]: m["index"] for m in found_microphones_info}
|
||||
raise OSError(
|
||||
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
||||
)
|
||||
|
||||
# Check if provided recording parameters are compatible with the microphone
|
||||
actual_microphone = sd.query_devices(self.microphone_index)
|
||||
|
||||
if self.sample_rate is not None:
|
||||
if self.sample_rate > actual_microphone["default_samplerate"]:
|
||||
raise OSError(
|
||||
f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
|
||||
)
|
||||
elif self.sample_rate < actual_microphone["default_samplerate"]:
|
||||
logging.warning(
|
||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||
)
|
||||
else:
|
||||
self.sample_rate = int(actual_microphone["default_samplerate"])
|
||||
|
||||
if self.channels is not None and len(self.channels) > 0:
|
||||
if any(c > actual_microphone["max_input_channels"] for c in self.channels):
|
||||
raise OSError(
|
||||
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
|
||||
)
|
||||
else:
|
||||
self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)
|
||||
|
||||
# Get channels index instead of number for slicing
|
||||
self.channels_index = np.array(self.channels) - 1
|
||||
|
||||
self._configure_capture_settings()
|
||||
# Create queues
|
||||
self.record_queue = process_Queue()
|
||||
self.read_queue = process_Queue()
|
||||
|
||||
Reference in New Issue
Block a user