Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WASAPI: Improve how default devices are created to support automatic stream routing #754

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ndk-glue = "0.7"

[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.54.0", features = [
"implement",
"Win32_Media_Audio",
"Win32_Foundation",
"Win32_Devices_Properties",
Expand Down
296 changes: 201 additions & 95 deletions src/host/wasapi/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@ use std::mem;
use std::os::windows::ffi::OsStringExt;
use std::ptr;
use std::slice;
use std::sync::mpsc::Sender;
use std::sync::OnceLock;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration;

use super::com;
use super::{windows_err_to_cpal_err, windows_err_to_cpal_err_message};
use windows::core::Interface;
use windows::core::GUID;
use windows::core::{implement, IUnknown, Interface, HRESULT, PCWSTR, PROPVARIANT};
use windows::Win32::Devices::Properties;
use windows::Win32::Foundation;
use windows::Win32::Media::Audio::IAudioRenderClient;
use windows::Win32::Media::{Audio, KernelStreaming, Multimedia};
use windows::Win32::System::Com;
use windows::Win32::System::Com::{StructuredStorage, STGM_READ};
use windows::Win32::System::Com::{CoTaskMemFree, StringFromIID, StructuredStorage, STGM_READ};
use windows::Win32::System::Threading;
use windows::Win32::System::Variant::VT_LPWSTR;

Expand All @@ -40,10 +41,17 @@ struct IAudioClientWrapper(Audio::IAudioClient);
unsafe impl Send for IAudioClientWrapper {}
unsafe impl Sync for IAudioClientWrapper {}

#[derive(Debug, Clone)]
enum DeviceType {
DefaultOutput,
DefaultInput,
Specific(Audio::IMMDevice),
}

/// An opaque type that identifies an end point.
#[derive(Clone)]
pub struct Device {
device: Audio::IMMDevice,
device: DeviceType,
/// We cache an uninitialized `IAudioClient` so that we can call functions from it without
/// having to create/destroy audio clients all the time.
future_audio_client: Arc<Mutex<Option<IAudioClientWrapper>>>, // TODO: add NonZero around the ptr
Expand Down Expand Up @@ -275,66 +283,133 @@ unsafe fn format_from_waveformatex_ptr(
Some(format)
}

#[implement(Audio::IActivateAudioInterfaceCompletionHandler)]
struct CompletionHandler(Sender<windows::core::Result<IUnknown>>);

fn retrieve_result(
operation: &Audio::IActivateAudioInterfaceAsyncOperation,
) -> windows::core::Result<IUnknown> {
let mut result = HRESULT::default();
let mut interface: Option<IUnknown> = None;
unsafe {
operation.GetActivateResult(&mut result, &mut interface)?;
}
result.ok()?;
Ok(interface.unwrap())
}

impl Audio::IActivateAudioInterfaceCompletionHandler_Impl for CompletionHandler {
fn ActivateCompleted(
&self,
operation: Option<&Audio::IActivateAudioInterfaceAsyncOperation>,
) -> windows::core::Result<()> {
let result = retrieve_result(operation.unwrap());
let _ = self.0.send(result);
Ok(())
}
}

#[allow(non_snake_case)]
unsafe fn ActivateAudioInterfaceSync<P0, T>(
deviceinterfacepath: P0,
activationparams: Option<*const PROPVARIANT>,
) -> windows::core::Result<T>
where
P0: windows::core::IntoParam<PCWSTR>,
T: Interface,
{
let (sender, receiver) = std::sync::mpsc::channel();
let completion: Audio::IActivateAudioInterfaceCompletionHandler =
CompletionHandler(sender).into();
Audio::ActivateAudioInterfaceAsync(
deviceinterfacepath,
&T::IID,
activationparams,
&completion,
)?;
let result = receiver.recv_timeout(Duration::from_secs(2)).unwrap()?;
result.cast()
}

unsafe impl Send for Device {}
unsafe impl Sync for Device {}

impl Device {
pub fn name(&self) -> Result<String, DeviceNameError> {
unsafe {
// Open the device's property store.
let property_store = self
.device
.OpenPropertyStore(STGM_READ)
.expect("could not open property store");

// Get the endpoint's friendly-name property.
let mut property_value = property_store
.GetValue(&Properties::DEVPKEY_Device_FriendlyName as *const _ as *const _)
.map_err(|err| {
let description =
format!("failed to retrieve name from property store: {}", err);
let err = BackendSpecificError { description };
DeviceNameError::from(err)
})?;
match &self.device {
DeviceType::DefaultOutput => Ok("Default Ouput".to_string()),
DeviceType::DefaultInput => Ok("Default Input".to_string()),
DeviceType::Specific(device) => unsafe {
// Open the device's property store.
let property_store = device
.OpenPropertyStore(STGM_READ)
.expect("could not open property store");

// Get the endpoint's friendly-name property.
let mut property_value = property_store
.GetValue(&Properties::DEVPKEY_Device_FriendlyName as *const _ as *const _)
.map_err(|err| {
let description =
format!("failed to retrieve name from property store: {}", err);
let err = BackendSpecificError { description };
DeviceNameError::from(err)
})?;

let prop_variant = &property_value.as_raw().Anonymous.Anonymous;
let prop_variant = &property_value.as_raw().Anonymous.Anonymous;

// Read the friendly-name from the union data field, expecting a *const u16.
if prop_variant.vt != VT_LPWSTR.0 {
let description = format!(
"property store produced invalid data: {:?}",
prop_variant.vt
);
let err = BackendSpecificError { description };
return Err(err.into());
}
let ptr_utf16 = *(&prop_variant.Anonymous as *const _ as *const *const u16);
// Read the friendly-name from the union data field, expecting a *const u16.
if prop_variant.vt != VT_LPWSTR.0 {
let description = format!(
"property store produced invalid data: {:?}",
prop_variant.vt
);
let err = BackendSpecificError { description };
return Err(err.into());
}
let ptr_utf16 = *(&prop_variant.Anonymous as *const _ as *const *const u16);

// Find the length of the friendly name.
let mut len = 0;
while *ptr_utf16.offset(len) != 0 {
len += 1;
}
// Find the length of the friendly name.
let mut len = 0;
while *ptr_utf16.offset(len) != 0 {
len += 1;
}

// Create the utf16 slice and convert it into a string.
let name_slice = slice::from_raw_parts(ptr_utf16, len as usize);
let name_os_string: OsString = OsStringExt::from_wide(name_slice);
let name_string = match name_os_string.into_string() {
Ok(string) => string,
Err(os_string) => os_string.to_string_lossy().into(),
};
// Create the utf16 slice and convert it into a string.
let name_slice = slice::from_raw_parts(ptr_utf16, len as usize);
let name_os_string: OsString = OsStringExt::from_wide(name_slice);
let name_string = match name_os_string.into_string() {
Ok(string) => string,
Err(os_string) => os_string.to_string_lossy().into(),
};

// Clean up the property.
StructuredStorage::PropVariantClear(&mut property_value).ok();
// Clean up the property.
StructuredStorage::PropVariantClear(&mut property_value).ok();

Ok(name_string)
Ok(name_string)
},
}
}

#[inline]
fn from_immdevice(device: Audio::IMMDevice) -> Self {
Device {
device,
device: DeviceType::Specific(device),
future_audio_client: Arc::new(Mutex::new(None)),
}
}

#[inline]
fn default_output() -> Self {
Device {
device: DeviceType::DefaultOutput,
future_audio_client: Arc::new(Mutex::new(None)),
}
}

#[inline]
fn default_input() -> Self {
Device {
device: DeviceType::DefaultInput,
future_audio_client: Arc::new(Mutex::new(None)),
}
}
Expand All @@ -349,9 +424,25 @@ impl Device {
}

let audio_client: Audio::IAudioClient = unsafe {
// can fail if the device has been disconnected since we enumerated it, or if
// the device doesn't support playback for some reason
self.device.Activate(Com::CLSCTX_ALL, None)?
match &self.device {
DeviceType::DefaultOutput => {
let default_audio = StringFromIID(&Audio::DEVINTERFACE_AUDIO_RENDER)?;
let result = ActivateAudioInterfaceSync(PCWSTR(default_audio.as_ptr()), None);
CoTaskMemFree(Some(default_audio.as_ptr() as _));
result?
}
DeviceType::DefaultInput => {
let default_audio = StringFromIID(&Audio::DEVINTERFACE_AUDIO_CAPTURE)?;
let result = ActivateAudioInterfaceSync(PCWSTR(default_audio.as_ptr()), None);
CoTaskMemFree(Some(default_audio.as_ptr() as _));
result?
}
DeviceType::Specific(device) => {
// can fail if the device has been disconnected since we enumerated it, or if
// the device doesn't support playback for some reason
device.Activate(Com::CLSCTX_ALL, None)?
}
}
};

*lock = Some(IAudioClientWrapper(audio_client));
Expand Down Expand Up @@ -518,8 +609,14 @@ impl Device {
}

pub(crate) fn data_flow(&self) -> Audio::EDataFlow {
let endpoint = Endpoint::from(self.device.clone());
endpoint.data_flow()
match &self.device {
DeviceType::DefaultOutput => Audio::eRender,
DeviceType::DefaultInput => Audio::eCapture,
DeviceType::Specific(device) => {
let endpoint = Endpoint::from(device.clone());
endpoint.data_flow()
}
}
}

pub fn default_input_config(&self) -> Result<SupportedStreamConfig, DefaultStreamConfigError> {
Expand Down Expand Up @@ -769,40 +866,47 @@ impl Device {
impl PartialEq for Device {
#[inline]
fn eq(&self, other: &Device) -> bool {
// Use case: In order to check whether the default device has changed
// the client code might need to compare the previous default device with the current one.
// The pointer comparison (`self.device == other.device`) don't work there,
// because the pointers are different even when the default device stays the same.
//
// In this code section we're trying to use the GetId method for the device comparison, cf.
// https://docs.microsoft.com/en-us/windows/desktop/api/mmdeviceapi/nf-mmdeviceapi-immdevice-getid
unsafe {
struct IdRAII(windows::core::PWSTR);
/// RAII for device IDs.
impl Drop for IdRAII {
fn drop(&mut self) {
unsafe { Com::CoTaskMemFree(Some(self.0 .0 as *mut _)) }
}
}
// GetId only fails with E_OUTOFMEMORY and if it does, we're probably dead already.
// Plus it won't do to change the device comparison logic unexpectedly.
let id1 = self.device.GetId().expect("cpal: GetId failure");
let id1 = IdRAII(id1);
let id2 = other.device.GetId().expect("cpal: GetId failure");
let id2 = IdRAII(id2);
// 16-bit null-terminated comparison.
let mut offset = 0;
loop {
let w1: u16 = *(id1.0).0.offset(offset);
let w2: u16 = *(id2.0).0.offset(offset);
if w1 == 0 && w2 == 0 {
return true;
}
if w1 != w2 {
return false;
match (&self.device, &other.device) {
(DeviceType::DefaultOutput, DeviceType::DefaultOutput) => true,
(DeviceType::DefaultInput, DeviceType::DefaultInput) => true,
(DeviceType::Specific(dev1), DeviceType::Specific(dev2)) => {
// Use case: In order to check whether the default device has changed
// the client code might need to compare the previous default device with the current one.
// The pointer comparison (`self.device == other.device`) don't work there,
// because the pointers are different even when the default device stays the same.
//
// In this code section we're trying to use the GetId method for the device comparison, cf.
// https://docs.microsoft.com/en-us/windows/desktop/api/mmdeviceapi/nf-mmdeviceapi-immdevice-getid
unsafe {
struct IdRAII(windows::core::PWSTR);
/// RAII for device IDs.
impl Drop for IdRAII {
fn drop(&mut self) {
unsafe { Com::CoTaskMemFree(Some(self.0 .0 as *mut _)) }
}
}
// GetId only fails with E_OUTOFMEMORY and if it does, we're probably dead already.
// Plus it won't do to change the device comparison logic unexpectedly.
let id1 = dev1.GetId().expect("cpal: GetId failure");
let id1 = IdRAII(id1);
let id2 = dev2.GetId().expect("cpal: GetId failure");
let id2 = IdRAII(id2);
// 16-bit null-terminated comparison.
let mut offset = 0;
loop {
let w1: u16 = *(id1.0).0.offset(offset);
let w2: u16 = *(id2.0).0.offset(offset);
if w1 == 0 && w2 == 0 {
return true;
}
if w1 != w2 {
return false;
}
offset += 1;
}
}
offset += 1;
}
_ => false,
}
}
}
Expand Down Expand Up @@ -914,23 +1018,25 @@ impl Iterator for Devices {
}
}

fn default_device(data_flow: Audio::EDataFlow) -> Option<Device> {
unsafe {
let device = get_enumerator()
.0
.GetDefaultAudioEndpoint(data_flow, Audio::eConsole)
.ok()?;
// TODO: check specifically for `E_NOTFOUND`, and panic otherwise
Some(Device::from_immdevice(device))
}
}
//fn default_device(data_flow: Audio::EDataFlow) -> Option<Device> {
// unsafe {
// let device = get_enumerator()
// .0
// .GetDefaultAudioEndpoint(data_flow, Audio::eConsole)
// .ok()?;
// // TODO: check specifically for `E_NOTFOUND`, and panic otherwise
// Some(Device::from_immdevice(device))
// }
//}

pub fn default_input_device() -> Option<Device> {
default_device(Audio::eCapture)
//default_device(Audio::eCapture)
Some(Device::default_input())
}

pub fn default_output_device() -> Option<Device> {
default_device(Audio::eRender)
//default_device(Audio::eRender)
Some(Device::default_output())
}

/// Get the audio clock used to produce `StreamInstant`s.
Expand Down