Skip to content

Commit

Permalink
Merge pull request #408 from microsoft/pete-dev
Browse files Browse the repository at this point in the history
KSA device removal handler
  • Loading branch information
Psychlist1972 authored Oct 8, 2024
2 parents ee1c058 + b6d3740 commit f759f56
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 55 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified build/staging/app-sdk/Arm64EC/mididiag.exe
Binary file not shown.
2 changes: 1 addition & 1 deletion build/staging/version/BundleInfo.wxi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Include>
<?define SetupVersionName="Developer Preview 7 Arm64" ?>
<?define SetupVersionNumber="1.0.24260.2222" ?>
<?define SetupVersionNumber="1.0.24281.1756" ?>
</Include>
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ CMidi2KSMidiEndpointManager::Cleanup()
m_DeviceStopped.revoke();
m_DeviceEnumerationCompleted.revoke();

m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();
m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();

return S_OK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ CMidi2KSAggregateMidiEndpointManager::Initialize(
MIDI_TRACE_EVENT_INFO,
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this")
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD)
);

RETURN_HR_IF(E_INVALIDARG, nullptr == midiDeviceManager);
Expand Down Expand Up @@ -86,7 +87,7 @@ CMidi2KSAggregateMidiEndpointManager::Initialize(
_Use_decl_annotations_
HRESULT
CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
KsAggregateEndpointDefinition& MasterEndpointDefinition
KsAggregateEndpointDefinition& masterEndpointDefinition
)
{
TraceLoggingWrite(
Expand All @@ -96,21 +97,24 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Creating aggregate UMP endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name")
);
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(masterEndpointDefinition.EndpointDeviceInstanceId.c_str(), "Endpoint Device Instance Id"),
TraceLoggingWideString(masterEndpointDefinition.FilterDeviceId.c_str(), "Filter Device Id"),
TraceLoggingWideString(masterEndpointDefinition.ParentDeviceInstanceId.c_str(), "Parent Device Instance Id")
);


// we require at least one valid pin
RETURN_HR_IF(E_INVALIDARG, MasterEndpointDefinition.Pins.size() < 1);
RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition.Pins.size() < 1);

std::vector<DEVPROPERTY> interfaceDevProperties;

MIDIENDPOINTCOMMONPROPERTIES commonProperties{};
commonProperties.AbstractionLayerGuid = ABSTRACTION_LAYER_GUID;
commonProperties.EndpointPurpose = MidiEndpointDevicePurposePropertyValue::NormalMessageEndpoint;
commonProperties.FriendlyName = MasterEndpointDefinition.EndpointName.c_str();
commonProperties.FriendlyName = masterEndpointDefinition.EndpointName.c_str();
commonProperties.TransportCode = TRANSPORT_CODE;
commonProperties.TransportSuppliedEndpointName = MasterEndpointDefinition.FilterName.c_str();
commonProperties.TransportSuppliedEndpointName = masterEndpointDefinition.FilterName.c_str();
commonProperties.TransportSuppliedEndpointDescription = nullptr;
commonProperties.UserSuppliedEndpointName = nullptr;
commonProperties.UserSuppliedEndpointDescription = nullptr;
Expand All @@ -126,10 +130,10 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
commonProperties.SupportsMidi2ProtocolDefaultValue = false;

interfaceDevProperties.push_back({ {DEVPKEY_KsMidiPort_KsFilterInterfaceId, DEVPROP_STORE_SYSTEM, nullptr},
DEVPROP_TYPE_STRING, static_cast<ULONG>((MasterEndpointDefinition.FilterDeviceId.length() + 1) * sizeof(WCHAR)), (PVOID)MasterEndpointDefinition.FilterDeviceId.c_str() });
DEVPROP_TYPE_STRING, static_cast<ULONG>((masterEndpointDefinition.FilterDeviceId.length() + 1) * sizeof(WCHAR)), (PVOID)masterEndpointDefinition.FilterDeviceId.c_str() });

interfaceDevProperties.push_back({ {DEVPKEY_KsTransport, DEVPROP_STORE_SYSTEM, nullptr },
DEVPROP_TYPE_UINT32, static_cast<ULONG>(sizeof(UINT32)), (PVOID)&MasterEndpointDefinition.TransportCapability });
DEVPROP_TYPE_UINT32, static_cast<ULONG>(sizeof(UINT32)), (PVOID)&masterEndpointDefinition.TransportCapability });

// create group terminal blocks and the pin map

Expand All @@ -140,7 +144,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
KSMIDI_PIN_MAP pinMap{ };
std::vector<internal::GroupTerminalBlockInternal> blocks;

for (auto const& pin : MasterEndpointDefinition.Pins)
for (auto const& pin : masterEndpointDefinition.Pins)
{
internal::GroupTerminalBlockInternal gtb;

Expand Down Expand Up @@ -218,9 +222,9 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
SW_DEVICE_CREATE_INFO createInfo{ };

createInfo.cbSize = sizeof(createInfo);
createInfo.pszInstanceId = MasterEndpointDefinition.EndpointDeviceInstanceId.c_str();
createInfo.pszInstanceId = masterEndpointDefinition.EndpointDeviceInstanceId.c_str();
createInfo.CapabilityFlags = SWDeviceCapabilitiesNone;
createInfo.pszDeviceDescription = MasterEndpointDefinition.EndpointName.c_str();
createInfo.pszDeviceDescription = masterEndpointDefinition.EndpointName.c_str();

// Call the device manager and finish the creation

Expand All @@ -231,7 +235,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(

LOG_IF_FAILED(
swdCreationResult = m_MidiDeviceManager->ActivateEndpoint(
MasterEndpointDefinition.ParentDeviceInstanceId.c_str(),
masterEndpointDefinition.ParentDeviceInstanceId.c_str(),
false, // TODO: create UMP only, handle the MIDI 1.0 compat ports
MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint
&commonProperties,
Expand All @@ -254,13 +258,18 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(newDeviceInterfaceId, "endpoint id")
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(newDeviceInterfaceId, MIDI_TRACE_EVENT_DEVICE_SWD_ID_FIELD)
);

// TODO: Add to internal endpoint manager, and also return the device interface id
// todo: return new device interface id


return swdCreationResult; // TODO change this to account for other steps

// Add to internal endpoint manager
m_availableEndpointDefinitions.push_back(std::move(masterEndpointDefinition));

return swdCreationResult;
}
else
{
Expand All @@ -271,7 +280,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD)
);

Expand Down Expand Up @@ -358,7 +367,7 @@ CMidi2KSAggregateMidiEndpointManager::OnDeviceAdded(
auto prop = properties.Lookup(winrt::to_hstring(L"System.Devices.DeviceInstanceId"));
RETURN_HR_IF_NULL(E_INVALIDARG, prop);
endpointDefinition.ParentDeviceInstanceId = winrt::unbox_value<winrt::hstring>(prop).c_str();
endpointDefinition.FilterDeviceId = device.Id().c_str();
endpointDefinition.FilterDeviceId = internal::NormalizeDeviceInstanceIdWStringCopy(device.Id().c_str());

// get the parent device
auto parentDeviceInfo = DeviceInformation::CreateFromIdAsync(endpointDefinition.ParentDeviceInstanceId,
Expand Down Expand Up @@ -584,43 +593,51 @@ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved(DeviceWatcher, Dev
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Enter", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(device.Id().c_str(), "device id")
);

auto cleanDeviceId = internal::NormalizeDeviceInstanceIdWStringCopy(device.Id().c_str());

// TODO
// the interface is no longer active, search through our m_AvailableMidiPins to identify
// every entry with this filter interface id, and remove the SWD and remove the pin(s) from
// the m_AvailableMidiPins list.
do
{
auto item = std::find_if(m_availableEndpointDefinitions.begin(),
m_availableEndpointDefinitions.end(), [&](const KsAggregateEndpointDefinition& endpointDefinition)
{
if (endpointDefinition.FilterDeviceId == cleanDeviceId)
{
return true;
}

return false;
});


if (item == m_availableEndpointDefinitions.end())
{
break;
}

TraceLoggingWrite(
MidiKSAggregateAbstractionTelemetryProvider::Provider(),
MIDI_TRACE_EVENT_INFO,
TraceLoggingString(__FUNCTION__, MIDI_TRACE_EVENT_LOCATION_FIELD),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Found device to remove", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(device.Id().c_str(), "device id")
);

// the interface is no longer active, search through our m_AvailableMidiPins to identify
// every entry with this filter interface id, and remove the SWD and remove the pin(s) from
// the m_AvailableMidiPins list.
//do
//{
// auto item = std::find_if(m_AvailableMidiPins.begin(), m_AvailableMidiPins.end(), [&](const std::unique_ptr<MIDI_PIN_INFO>& Pin)
// {
// // if this interface id already activated, then we cannot activate/create a second time,
// if (device.Id() == Pin->Id)
// {
// return true;
// }

// return false;
// });

// if (item == m_AvailableMidiPins.end())
// {
// break;
// }

// // notify the device manager using the InstanceId for this midi device
// LOG_IF_FAILED(m_MidiDeviceManager->RemoveEndpoint(item->get()->InstanceId.c_str()));

// // remove the MIDI_PIN_INFO from the list
// m_AvailableMidiPins.erase(item);
//}
//while (TRUE);
// notify the device manager using the InstanceId for this midi device
LOG_IF_FAILED(m_MidiDeviceManager->RemoveEndpoint(device.Id().c_str()));

// remove the MIDI_PIN_INFO from the list
m_availableEndpointDefinitions.erase(item);
}
while (TRUE);

return S_OK;
}
Expand Down Expand Up @@ -669,8 +686,8 @@ CMidi2KSAggregateMidiEndpointManager::Cleanup()
m_DeviceStopped.revoke();
m_DeviceEnumerationCompleted.revoke();

m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();
m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();

return S_OK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ struct KsAggregateEndpointDefinition

MidiTransport TransportCapability;


std::vector<KsAggregateEndpointPinDefinition> Pins;
};

Expand Down Expand Up @@ -59,7 +58,7 @@ class CMidi2KSAggregateMidiEndpointManager :
wil::com_ptr_nothrow<IMidiDeviceManagerInterface> m_MidiDeviceManager;
wil::com_ptr_nothrow<IMidiEndpointProtocolManagerInterface> m_MidiProtocolManager;

//std::vector<std::unique_ptr<MIDI_PIN_INFO>> m_AvailableMidiPins;
std::vector<KsAggregateEndpointDefinition> m_availableEndpointDefinitions;

DeviceWatcher m_Watcher{0};
winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher<IDeviceWatcher>::Added_revoker m_DeviceAdded;
Expand Down

0 comments on commit f759f56

Please sign in to comment.