Fix the dispatch failure when output C for addmm is not transposed #189

Open · wants to merge 1 commit into base: main
20 changes: 12 additions & 8 deletions docker/pytorch-aarch64/patches/blas_to_mkl_acl.patch
@@ -16,10 +16,10 @@
# *******************************************************************************

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index c658d4427c..5c792f0e73 100644
index a0531c50c96..55102c9d2f5 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1308,6 +1308,16 @@ static void addmm_impl_cpu_(
@@ -1420,6 +1420,20 @@ static void addmm_impl_cpu_(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
@@ -29,21 +29,25 @@ index c658d4427c..5c792f0e73 100644
+ // that will call then into ACL GEMM kernel and also additionaly have support
+ // for running kernel with BF16 instructions
+ if(transpose_a && !transpose_b) {
+ if (transpose_c) {
+ mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
+ return;
+ } else {
+ mkldnn_matmul(a, b, c, beta.to<float>(), alpha.to<float>());
+ }
+ return;
+ }
+ #endif
using opmath_t = at::opmath_type<scalar_t>;
at::native::cpublas::gemm(
transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
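
For context, here is a minimal standalone sketch of the dispatch rule the updated hunk introduces: when A is transposed and B is not, addmm_impl_cpu_ is routed to mkldnn_matmul, and the operand order now depends on whether the output C is transposed. The names Mat, mkldnn_matmul_stub, gemm_fallback and dispatch_addmm below are illustrative stand-ins, not PyTorch APIs; the real code operates on at::Tensor and falls back to at::native::cpublas::gemm.

#include <iostream>

// Illustrative stand-in for a tensor; only the name is used for tracing.
struct Mat { const char* name; };

void mkldnn_matmul_stub(const Mat& m1, const Mat& m2, const Mat& out,
                        float beta, float alpha) {
  std::cout << "mkldnn_matmul(" << m1.name << ", " << m2.name << ") -> "
            << out.name << " beta=" << beta << " alpha=" << alpha << "\n";
}

void gemm_fallback(const Mat&, const Mat&, const Mat&) {
  std::cout << "cpublas::gemm fallback\n";
}

void dispatch_addmm(bool transpose_a, bool transpose_b, bool transpose_c,
                    const Mat& a, const Mat& b, const Mat& c,
                    float beta, float alpha) {
  if (transpose_a && !transpose_b) {
    if (transpose_c) {
      // Output is stored transposed: keep the original operand swap (b, a).
      mkldnn_matmul_stub(b, a, c, beta, alpha);
    } else {
      // Output is not transposed: pass (a, b), the case the old patch missed.
      mkldnn_matmul_stub(a, b, c, beta, alpha);
    }
    return;
  }
  gemm_fallback(a, b, c);  // every other layout stays on the cpublas path
}

int main() {
  Mat a{"A"}, b{"B"}, c{"C"};
  // The layout combination that previously failed to dispatch correctly.
  dispatch_addmm(/*transpose_a=*/true, /*transpose_b=*/false,
                 /*transpose_c=*/false, a, b, c, 1.0f, 1.0f);
}
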
diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
Contributor
I don't think any of the changes below this point are required.

index d41ebac635..e2cc13fe00 100644
index 383d2965923..b15056d7161 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.cpp
+++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -128,23 +128,25 @@ void mkldnn_matmul(
@@ -130,23 +130,25 @@ void mkldnn_matmul(
(mat1.dim() == 1 && mat2.dim() == 1), // aten::dot
"mkldnn_matmul: unsupported dims for mat and mat2");

+#if defined(__aarch64__)
+ // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
+ // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
@@ -58,7 +62,7 @@ index d41ebac635..e2cc13fe00 100644
+#else
TORCH_CHECK(mkldnn_bf16_device_check(),
"mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq, or AWS Graviton3");

-#if defined(__aarch64__)
- if (mkldnn_bf16_device_check_arm()) {
- //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
@@ -76,6 +80,6 @@ index d41ebac635..e2cc13fe00 100644
- mat2.scalar_type() == at::kBFloat16 &&
- result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
- }

auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
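
As a rough illustration of what the revised type check in mkldnn_matmul amounts to after this patch, assuming the intent shown in the hunk above: on aarch64 builds both all-fp32 and all-bf16 inputs are accepted, since oneDNN fast-math mode (enabled with ONEDNN_DEFAULT_FPMATH_MODE=BF16) may lower fp32 to bf16 kernels where the hardware permits, while other architectures keep the bf16-only requirement. The enum and helper below are simplified stand-ins, not the actual PyTorch types or error handling.

#include <stdexcept>

enum class ScalarType { Float, BFloat16, Other };  // simplified stand-in for at::ScalarType

inline void check_mkldnn_matmul_dtypes(ScalarType mat1, ScalarType mat2,
                                       ScalarType result) {
#if defined(__aarch64__)
  // aarch64: fp32 inputs can be lowered to bf16 kernels by oneDNN fast-math,
  // so both uniform fp32 and uniform bf16 are allowed through.
  const bool all_fp32 = mat1 == ScalarType::Float &&
                        mat2 == ScalarType::Float &&
                        result == ScalarType::Float;
  const bool all_bf16 = mat1 == ScalarType::BFloat16 &&
                        mat2 == ScalarType::BFloat16 &&
                        result == ScalarType::BFloat16;
  if (!(all_fp32 || all_bf16)) {
    throw std::invalid_argument(
        "mkldnn_matmul: only uniform fp32 or bf16 inputs are routed to oneDNN here");
  }
#else
  // Other architectures: unchanged bf16-only requirement.
  const bool all_bf16 = mat1 == ScalarType::BFloat16 &&
                        mat2 == ScalarType::BFloat16 &&
                        result == ScalarType::BFloat16;
  if (!all_bf16) {
    throw std::invalid_argument("mkldnn_matmul: only enabled for bf16 path");
  }
#endif
}
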