Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KSA device removal handler #408

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified build/staging/app-sdk/Arm64EC/mididiag.exe
Binary file not shown.
2 changes: 1 addition & 1 deletion build/staging/version/BundleInfo.wxi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Include>
<?define SetupVersionName="Developer Preview 7 Arm64" ?>
<?define SetupVersionNumber="1.0.24260.2222" ?>
<?define SetupVersionNumber="1.0.24281.1756" ?>
</Include>
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ CMidi2KSMidiEndpointManager::Cleanup()
m_DeviceStopped.revoke();
m_DeviceEnumerationCompleted.revoke();

m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();
m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();

return S_OK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ CMidi2KSAggregateMidiEndpointManager::Initialize(
MIDI_TRACE_EVENT_INFO,
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this")
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD)
);

RETURN_HR_IF(E_INVALIDARG, nullptr == midiDeviceManager);
Expand Down Expand Up @@ -86,7 +87,7 @@ CMidi2KSAggregateMidiEndpointManager::Initialize(
_Use_decl_annotations_
HRESULT
CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
KsAggregateEndpointDefinition& MasterEndpointDefinition
KsAggregateEndpointDefinition& masterEndpointDefinition
)
{
TraceLoggingWrite(
Expand All @@ -96,21 +97,24 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Creating aggregate UMP endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name")
);
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(masterEndpointDefinition.EndpointDeviceInstanceId.c_str(), "Endpoint Device Instance Id"),
TraceLoggingWideString(masterEndpointDefinition.FilterDeviceId.c_str(), "Filter Device Id"),
TraceLoggingWideString(masterEndpointDefinition.ParentDeviceInstanceId.c_str(), "Parent Device Instance Id")
);


// we require at least one valid pin
RETURN_HR_IF(E_INVALIDARG, MasterEndpointDefinition.Pins.size() < 1);
RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition.Pins.size() < 1);

std::vector<DEVPROPERTY> interfaceDevProperties;

MIDIENDPOINTCOMMONPROPERTIES commonProperties{};
commonProperties.AbstractionLayerGuid = ABSTRACTION_LAYER_GUID;
commonProperties.EndpointPurpose = MidiEndpointDevicePurposePropertyValue::NormalMessageEndpoint;
commonProperties.FriendlyName = MasterEndpointDefinition.EndpointName.c_str();
commonProperties.FriendlyName = masterEndpointDefinition.EndpointName.c_str();
commonProperties.TransportCode = TRANSPORT_CODE;
commonProperties.TransportSuppliedEndpointName = MasterEndpointDefinition.FilterName.c_str();
commonProperties.TransportSuppliedEndpointName = masterEndpointDefinition.FilterName.c_str();
commonProperties.TransportSuppliedEndpointDescription = nullptr;
commonProperties.UserSuppliedEndpointName = nullptr;
commonProperties.UserSuppliedEndpointDescription = nullptr;
Expand All @@ -126,10 +130,10 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
commonProperties.SupportsMidi2ProtocolDefaultValue = false;

interfaceDevProperties.push_back({ {DEVPKEY_KsMidiPort_KsFilterInterfaceId, DEVPROP_STORE_SYSTEM, nullptr},
DEVPROP_TYPE_STRING, static_cast<ULONG>((MasterEndpointDefinition.FilterDeviceId.length() + 1) * sizeof(WCHAR)), (PVOID)MasterEndpointDefinition.FilterDeviceId.c_str() });
DEVPROP_TYPE_STRING, static_cast<ULONG>((masterEndpointDefinition.FilterDeviceId.length() + 1) * sizeof(WCHAR)), (PVOID)masterEndpointDefinition.FilterDeviceId.c_str() });

interfaceDevProperties.push_back({ {DEVPKEY_KsTransport, DEVPROP_STORE_SYSTEM, nullptr },
DEVPROP_TYPE_UINT32, static_cast<ULONG>(sizeof(UINT32)), (PVOID)&MasterEndpointDefinition.TransportCapability });
DEVPROP_TYPE_UINT32, static_cast<ULONG>(sizeof(UINT32)), (PVOID)&masterEndpointDefinition.TransportCapability });

// create group terminal blocks and the pin map

Expand All @@ -140,7 +144,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
KSMIDI_PIN_MAP pinMap{ };
std::vector<internal::GroupTerminalBlockInternal> blocks;

for (auto const& pin : MasterEndpointDefinition.Pins)
for (auto const& pin : masterEndpointDefinition.Pins)
{
internal::GroupTerminalBlockInternal gtb;

Expand Down Expand Up @@ -218,9 +222,9 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
SW_DEVICE_CREATE_INFO createInfo{ };

createInfo.cbSize = sizeof(createInfo);
createInfo.pszInstanceId = MasterEndpointDefinition.EndpointDeviceInstanceId.c_str();
createInfo.pszInstanceId = masterEndpointDefinition.EndpointDeviceInstanceId.c_str();
createInfo.CapabilityFlags = SWDeviceCapabilitiesNone;
createInfo.pszDeviceDescription = MasterEndpointDefinition.EndpointName.c_str();
createInfo.pszDeviceDescription = masterEndpointDefinition.EndpointName.c_str();

// Call the device manager and finish the creation

Expand All @@ -231,7 +235,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(

LOG_IF_FAILED(
swdCreationResult = m_MidiDeviceManager->ActivateEndpoint(
MasterEndpointDefinition.ParentDeviceInstanceId.c_str(),
masterEndpointDefinition.ParentDeviceInstanceId.c_str(),
false, // TODO: create UMP only, handle the MIDI 1.0 compat ports
MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint
&commonProperties,
Expand All @@ -254,13 +258,18 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(newDeviceInterfaceId, "endpoint id")
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(newDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD)
);

// TODO: Add to internal endpoint manager, and also return the device interface id
// todo: return new device interface id


return swdCreationResult; // TODO change this to account for other steps

// Add to internal endpoint manager
m_availableEndpointDefinitions.push_back(std::move(masterEndpointDefinition));

return swdCreationResult;
}
else
{
Expand All @@ -271,7 +280,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD)
);

Expand Down Expand Up @@ -358,7 +367,7 @@ CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded(
auto prop = properties.Lookup(winrt::to_hstring(L"System.Devices.DeviceInstanceId"));
RETURN_HR_IF_NULL(E_INVALIDARG, prop);
endpointDefinition.ParentDeviceInstanceId = winrt::unbox_value<winrt::hstring>(prop).c_str();
endpointDefinition.FilterDeviceId = device.Id().c_str();
endpointDefinition.FilterDeviceId = internal::NormalizeDeviceInstanceIdWStringCopy(device.Id().c_str());

// get the parent device
auto parentDeviceInfo = DeviceInformation::CreateFromIdAsync(endpointDefinition.ParentDeviceInstanceId,
Expand Down Expand Up @@ -584,43 +593,51 @@ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved(DeviceWatcher, Dev
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(device.Id().c_str(), "device id")
);

auto cleanDeviceId = internal::NormalizeDeviceInstanceIdWStringCopy(device.Id().c_str());

// TODO
// the interface is no longer active, search through our m_AvailableMidiPins to identify
// every entry with this filter interface id, and remove the SWD and remove the pin(s) from
// the m_AvailableMidiPins list.
do
{
auto item = std::find_if(m_availableEndpointDefinitions.begin(),
m_availableEndpointDefinitions.end(), [&](const KsAggregateEndpointDefinition& endpointDefinition)
{
if (endpointDefinition.FilterDeviceId == cleanDeviceId)
{
return true;
}

return false;
});


if (item == m_availableEndpointDefinitions.end())
{
break;
}

TraceLoggingWrite(
MidiKSAggregateAbstractionTelemetryProvider::Provider(),
MIDI_TRACE_EVENT_INFO,
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Found device to remove", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(device.Id().c_str(), "device id")
);

// the interface is no longer active, search through our m_AvailableMidiPins to identify
// every entry with this filter interface id, and remove the SWD and remove the pin(s) from
// the m_AvailableMidiPins list.
//do
//{
// auto item = std::find_if(m_AvailableMidiPins.begin(), m_AvailableMidiPins.end(), [&](const std::unique_ptr<MIDI_PIN_INFO>& Pin)
// {
// // if this interface id already activated, then we cannot activate/create a second time,
// if (device.Id() == Pin->Id)
// {
// return true;
// }

// return false;
// });

// if (item == m_AvailableMidiPins.end())
// {
// break;
// }

// // notify the device manager using the InstanceId for this midi device
// LOG_IF_FAILED(m_MidiDeviceManager->RemoveEndpoint(item->get()->InstanceId.c_str()));

// // remove the MIDI_PIN_INFO from the list
// m_AvailableMidiPins.erase(item);
//}
//while (TRUE);
// notify the device manager using the InstanceId for this midi device
LOG_IF_FAILED(m_MidiDeviceManager->RemoveEndpoint(device.Id().c_str()));

// remove the MIDI_PIN_INFO from the list
m_availableEndpointDefinitions.erase(item);
}
while (TRUE);

return S_OK;
}
Expand Down Expand Up @@ -669,8 +686,8 @@ CMidi2KSAggregateMidiEndpointManager::Cleanup()
m_DeviceStopped.revoke();
m_DeviceEnumerationCompleted.revoke();

m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();
m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();

return S_OK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ struct KsAggregateEndpointDefinition

MidiTransport TransportCapability;


std::vector<KsAggregateEndpointPinDefinition> Pins;
};

Expand Down Expand Up @@ -59,7 +58,7 @@ class CMidi2KSAggregateMidiEndpointManager :
wil::com_ptr_nothrow<IMidiDeviceManagerInterface> m_MidiDeviceManager;
wil::com_ptr_nothrow<IMidiEndpointProtocolManagerInterface> m_MidiProtocolManager;

//std::vector<std::unique_ptr<MIDI_PIN_INFO>> m_AvailableMidiPins;
std::vector<KsAggregateEndpointDefinition> m_availableEndpointDefinitions;

DeviceWatcher m_Watcher{0};
winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher<IDeviceWatcher>::Added_revoker m_DeviceAdded;
Expand Down