diff --git a/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp index aaacb45ede..6443b7772d 100644 --- a/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_QR_FormQ_Serial_Internal.hpp @@ -49,12 +49,16 @@ struct SerialQR_FormQ_Internal { /// B is m x m // set identity - if (is_Q_zero) - SerialSetInternal::invoke(m, value_type(1), Q, qs0 + qs1); - else + if (is_Q_zero) { + for (int idx = 0; idx < m; ++idx) { + Q[(qs0 + qs1) * idx] = value_type(1); + // SerialSetInternal::invoke(m, value_type(1), Q, qs0 + qs1); + } + } else { SerialSetIdentityInternal::invoke(m, Q, qs0, qs1); + } - return SerialApplyQ_LeftNoTransForwardInternal ::invoke(m, m, k, A, as0, as1, t, ts, Q, qs0, qs1, w); + return SerialApplyQ_LeftForwardInternal::invoke(m, m, k, A, as0, as1, t, ts, Q, qs0, qs1, w); } }; diff --git a/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp index 8aa4a6361c..851bd85ab9 100644 --- a/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_QR_Serial_Internal.hpp @@ -53,7 +53,7 @@ struct SerialQR_Internal { A_part2x2.partWithATL(A, m, n, 0, 0); t_part2x1.partWithAT(t, m, 0); - for (int m_atl = 0; m_atl < m; ++m_atl) { + for (int m_atl = 0; m_atl < Kokkos::min(m, n); ++m_atl) { // part 2x2 into 3x3 A_part3x3.partWithABR(A_part2x2, 1, 1); const int m_A22 = m - m_atl - 1; diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index 2378e5ff01..3a081b10cc 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -30,6 +30,7 @@ #include "Test_Batched_SerialLU.hpp" #include "Test_Batched_SerialLU_Real.hpp" #include "Test_Batched_SerialLU_Complex.hpp" +#include "Test_Batched_SerialQR.hpp" #include "Test_Batched_SerialSolveLU.hpp" #include "Test_Batched_SerialSolveLU_Real.hpp" #include "Test_Batched_SerialSolveLU_Complex.hpp"