diff --git a/src/lerobot/microphones/portaudio/microphone_portaudio.py b/src/lerobot/microphones/portaudio/microphone_portaudio.py index a601a7d04..2fe056898 100644 --- a/src/lerobot/microphones/portaudio/microphone_portaudio.py +++ b/src/lerobot/microphones/portaudio/microphone_portaudio.py @@ -123,6 +123,74 @@ class PortAudioMicrophone(Microphone): return found_microphones_info + def _configure_capture_settings(self) -> None: + """ + Validates the microphone index, sample rate and channels settings specified in the constructor's config to the un-connected microphone. + + This method actually checks the specified settings and fills the sample rate and channels settings if not specified before attempting to start a PortAudio stream. + + Raises: + RuntimeError: If one of the specified settings is not compatible with the microphone. + DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError( + f"Cannot configure settings for {self} as it is already connected." + ) + + self._validate_microphone_index() + self._validate_sample_rate() + self._validate_channels() + + def _validate_microphone_index(self) -> None: + """ "Validates the microphone index against available devices by checking if it has at least one input channel.""" + + is_index_input = ( + self.microphone_index >= 0 and sd.query_devices(self.microphone_index)["max_input_channels"] > 0 + ) + + if not is_index_input: + found_microphones_info = self.find_microphones() + available_microphones = {m["name"]: m["index"] for m in found_microphones_info} + raise RuntimeError( + f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" + ) + + def _validate_sample_rate(self) -> None: + """Validates the sample rate against the actual microphone's default sample rate.""" + + actual_sample_rate = sd.query_devices(self.microphone_index)["default_samplerate"] + + if self.sample_rate is not None: + if self.sample_rate > actual_sample_rate or self.sample_rate < 1000: + raise RuntimeError( + f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}." + ) + else: + if self.sample_rate < actual_sample_rate: + logging.warning( + "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." + ) + self.sample_rate = int(self.sample_rate) + else: + self.sample_rate = int(actual_sample_rate) + + def _validate_channels(self) -> None: + """Validates the channels against the actual microphone's maximum input channels.""" + + actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"] + + if self.channels is not None and len(self.channels) > 0: + if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels): + raise RuntimeError( + f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}." + ) + else: + self.channels = np.arange(1, actual_max_microphone_channels + 1) + + # Get channels index instead of number for slicing + self.channels_index = np.array(self.channels) - 1 + def connect(self) -> None: """ Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone. @@ -130,42 +198,7 @@ class PortAudioMicrophone(Microphone): if self.is_connected: raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.") - # Check if the provided microphone index does match an input device - is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0 - - if not is_index_input: - found_microphones_info = self.find_microphones() - available_microphones = {m["name"]: m["index"] for m in found_microphones_info} - raise OSError( - f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" - ) - - # Check if provided recording parameters are compatible with the microphone - actual_microphone = sd.query_devices(self.microphone_index) - - if self.sample_rate is not None: - if self.sample_rate > actual_microphone["default_samplerate"]: - raise OSError( - f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}." - ) - elif self.sample_rate < actual_microphone["default_samplerate"]: - logging.warning( - "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." - ) - else: - self.sample_rate = int(actual_microphone["default_samplerate"]) - - if self.channels is not None and len(self.channels) > 0: - if any(c > actual_microphone["max_input_channels"] for c in self.channels): - raise OSError( - f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}." - ) - else: - self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1) - - # Get channels index instead of number for slicing - self.channels_index = np.array(self.channels) - 1 - + self._configure_capture_settings() # Create queues self.record_queue = process_Queue() self.read_queue = process_Queue()