From e000a89d68ed87a2bea24f10d4f4796d8412a95e Mon Sep 17 00:00:00 2001 From: Luc Berger-Vergiat Date: Thu, 19 Sep 2024 13:38:15 -0600 Subject: [PATCH] batched - dense: Testing and fixing Serial QR The serial QR algorithms does not have unit-tests and is failing for non square matrices. See issue #2328. This first commit fixes the issue with rectangular matrices and adds a basic test for that use case. Next will work on adding a test that exercises the interfaces on multiple matrices of different sizes within a parallel_for. Finally equivalent tests will be added for the square case as well. Signed-off-by: Luc --- .../impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp | 12 ++++++++---- .../dense/impl/KokkosBatched_QR_Serial_Internal.hpp | 2 +- batched/dense/unit_test/Test_Batched_Dense.hpp | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp index aaacb45ede..6443b7772d 100644 --- a/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp @@ -49,12 +49,16 @@ struct SerialQR_FormQ_Internal { /// B is m x m // set identity - if (is_Q_zero) - SerialSetInternal::invoke(m, value_type(1), Q, qs0 + qs1); - else + if (is_Q_zero) { + for (int idx = 0; idx < m; ++idx) { + Q[(qs0 + qs1) * idx] = value_type(1); + // SerialSetInternal::invoke(m, value_type(1), Q, qs0 + qs1); + } + } else { SerialSetIdentityInternal::invoke(m, Q, qs0, qs1); + } - return SerialApplyQ_LeftNoTransForwardInternal ::invoke(m, m, k, A, as0, as1, t, ts, Q, qs0, qs1, w); + return SerialApplyQ_LeftForwardInternal::invoke(m, m, k, A, as0, as1, t, ts, Q, qs0, qs1, w); } }; diff --git a/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp index 8aa4a6361c..851bd85ab9 100644 --- a/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp @@ -53,7 +53,7 @@ struct SerialQR_Internal { A_part2x2.partWithATL(A, m, n, 0, 0); t_part2x1.partWithAT(t, m, 0); - for (int m_atl = 0; m_atl < m; ++m_atl) { + for (int m_atl = 0; m_atl < Kokkos::min(m, n); ++m_atl) { // part 2x2 into 3x3 A_part3x3.partWithABR(A_part2x2, 1, 1); const int m_A22 = m - m_atl - 1; diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index 2378e5ff01..3a081b10cc 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -30,6 +30,7 @@ #include "Test_Batched_SerialLU.hpp" #include "Test_Batched_SerialLU_Real.hpp" #include "Test_Batched_SerialLU_Complex.hpp" +#include "Test_Batched_SerialQR.hpp" #include "Test_Batched_SerialSolveLU.hpp" #include "Test_Batched_SerialSolveLU_Real.hpp" #include "Test_Batched_SerialSolveLU_Complex.hpp"