diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 1403c99cc35cd..90e6516ff45d1 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -782,9 +782,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const bool has_zero_points = zero_points != nullptr; const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); + // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. + // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal; if (accuracy_level_ == 4 && block_size == 32 && batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 && - !has_zero_points && has_subgroup && M >= kMinMForTileOptimization) { + !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) { constexpr uint32_t kVec4Components = 4; constexpr uint32_t kVec2Components = 2; constexpr uint32_t kU32Components = 4;