From 4055b25b1186dab8e586beb909fa740efc894f62 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 27 Sep 2024 12:00:45 -0500 Subject: [PATCH] Remove dynamic cast from headers --- src/BLR/BLRMatrixMPI.cpp | 18 +++++++++++ src/BLR/BLRMatrixMPI.hpp | 16 ++-------- src/HSS/HSSMatrix.cpp | 10 ++++++ src/HSS/HSSMatrix.hpp | 8 ++--- src/HSS/HSSMatrixMPI.Schur.hpp | 57 --------------------------------- src/HSS/HSSMatrixMPI.cpp | 58 ++++++++++++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 77 deletions(-) diff --git a/src/BLR/BLRMatrixMPI.cpp b/src/BLR/BLRMatrixMPI.cpp index 473f139c..b7caccf4 100644 --- a/src/BLR/BLRMatrixMPI.cpp +++ b/src/BLR/BLRMatrixMPI.cpp @@ -139,6 +139,24 @@ namespace strumpack { return m; } + template DenseTile& + BLRMatrixMPI::ltile_dense(std::size_t i, std::size_t j) { + assert(i < rowblockslocal() && j < colblockslocal()); + assert(dynamic_cast*> + (blocks_[i+j*rowblockslocal()].get())); + return *static_cast*> + (blocks_[i+j*rowblockslocal()].get()); + } + + template const DenseTile& + BLRMatrixMPI::ltile_dense(std::size_t i, std::size_t j) const { + assert(i < rowblockslocal() && j < colblockslocal()); + assert(dynamic_cast*> + (blocks_[i+j*rowblockslocal()].get())); + return *static_cast*> + (blocks_[i+j*rowblockslocal()].get()); + } + template typename RealType::value_type BLRMatrixMPI::normF() const { diff --git a/src/BLR/BLRMatrixMPI.hpp b/src/BLR/BLRMatrixMPI.hpp index 5d82d3a4..703825f2 100644 --- a/src/BLR/BLRMatrixMPI.hpp +++ b/src/BLR/BLRMatrixMPI.hpp @@ -272,20 +272,8 @@ namespace strumpack { return *blocks_[i+j*rowblockslocal()].get(); } - DenseTile& ltile_dense(std::size_t i, std::size_t j) { - assert(i < rowblockslocal() && j < colblockslocal()); - assert(dynamic_cast*> - (blocks_[i+j*rowblockslocal()].get())); - return *static_cast*> - (blocks_[i+j*rowblockslocal()].get()); - } - const DenseTile& ltile_dense(std::size_t i, std::size_t j) const { - assert(i < rowblockslocal() && j < colblockslocal()); - assert(dynamic_cast*> - (blocks_[i+j*rowblockslocal()].get())); - return *static_cast*> - (blocks_[i+j*rowblockslocal()].get()); - } + DenseTile& ltile_dense(std::size_t i, std::size_t j); + const DenseTile& ltile_dense(std::size_t i, std::size_t j) const; std::unique_ptr>& block(std::size_t i, std::size_t j) { diff --git a/src/HSS/HSSMatrix.cpp b/src/HSS/HSSMatrix.cpp index 84b3af2d..4905bb73 100644 --- a/src/HSS/HSSMatrix.cpp +++ b/src/HSS/HSSMatrix.cpp @@ -138,6 +138,16 @@ namespace strumpack { (new HSSMatrix(*this)); } + template const HSSMatrix* + HSSMatrix::child(int c) const { + return dynamic_cast*>(this->ch_[c].get()); + } + + template HSSMatrix* + HSSMatrix::child(int c) { + return dynamic_cast*>(this->ch_[c].get()); + } + template void HSSMatrix::delete_trailing_block() { B01_.clear(); diff --git a/src/HSS/HSSMatrix.hpp b/src/HSS/HSSMatrix.hpp index 7ff75ef6..2b85797c 100644 --- a/src/HSS/HSSMatrix.hpp +++ b/src/HSS/HSSMatrix.hpp @@ -191,9 +191,7 @@ namespace strumpack { * matrix. The value of c should be 0 or 1, and this HSS matrix * should not be a leaf! */ - const HSSMatrix* child(int c) const { - return dynamic_cast*>(this->ch_[c].get()); - } + const HSSMatrix* child(int c) const; /** * Return a raw (non-owning) pointer to child c of this HSS @@ -201,9 +199,7 @@ namespace strumpack { * value of c should be 0 or 1, and this HSS matrix should not * be a leaf! */ - HSSMatrix* child(int c) { - return dynamic_cast*>(this->ch_[c].get()); - } + HSSMatrix* child(int c); /** * Initialize this HSS matrix as the compressed HSS diff --git a/src/HSS/HSSMatrixMPI.Schur.hpp b/src/HSS/HSSMatrixMPI.Schur.hpp index 26508417..4469e644 100644 --- a/src/HSS/HSSMatrixMPI.Schur.hpp +++ b/src/HSS/HSSMatrixMPI.Schur.hpp @@ -39,63 +39,6 @@ namespace strumpack { * Phi = (D0^{-1} * U0 * B01 * V1big^C)^C * = V1big * (D0^{-1} * U0 * B01)^C */ - template void HSSMatrixMPI::Schur_update - (DistM_t& Theta, DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const { - if (this->leaf()) return; - auto ch0 = child(0); - auto ch1 = child(1); - DistM_t DU(grid(), ch0->U_rows(), ch0->U_rank()); - if (auto ch0mpi = - dynamic_cast*>(child(0))) { - DistM_t chDU; - if (ch0mpi->active()) { - chDU = ch0->ULV_mpi_.D_.solve(ch0mpi->U_.dense(), ch0->ULV_mpi_.piv_); - STRUMPACK_SCHUR_FLOPS - (!ch0->ULV_mpi_.D_.is_master() ? 0 : - blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0mpi->U_.cols())); - } - copy(ch0->U_rows(), ch0->U_rank(), chDU, 0, 0, DU, 0, 0, grid()->ctxt_all()); - } else { - auto ch0seq = dynamic_cast*>(child(0)); - DenseM_t chDU; - if (ch0seq->active()) { - chDU = ch0->ULV_mpi_.D_.gather().solve - (ch0seq->U_.dense(), ch0->ULV_mpi_.piv_, ch0seq->openmp_task_depth_); - STRUMPACK_SCHUR_FLOPS - (!ch0->ULV_mpi_.D_.is_master() ? 0 : - blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0seq->U_.cols())); - } - copy(ch0->U_rows(), ch0->U_rank(), chDU, 0/*rank ch0*/, DU, 0, 0, grid()->ctxt_all()); - } - DUB01 = DistM_t(grid(), ch0->U_rows(), ch1->V_rank()); - gemm(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.), DUB01); - STRUMPACK_SCHUR_FLOPS - (gemm_flops(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.))); - - DistM_t _theta(ch1->grid(grid_local()), B10_.rows(), B10_.cols()); - copy(B10_.rows(), B10_.cols(), B10_, 0, 0, _theta, 0, 0, grid()->ctxt_all()); - auto DUB01t = DUB01.transpose(); - DistM_t _phi(ch1->grid(grid_local()), DUB01t.rows(), DUB01t.cols()); - copy(DUB01t.rows(), DUB01t.cols(), DUB01t, 0, 0, _phi, 0, 0, grid()->ctxt_all()); - DUB01t.clear(); - - DistSubLeaf Theta_br(_theta.cols(), ch1, grid_local()), - Phi_br(_phi.cols(), ch1, grid_local()); - DistM_t Theta_ch(ch1->grid(grid_local()), ch1->rows(), _theta.cols()); - DistM_t Phi_ch(ch1->grid(grid_local()), ch1->cols(), _phi.cols()); - long long int flops = 0; - ch1->apply_UV_big(Theta_br, _theta, Phi_br, _phi, flops); - STRUMPACK_SCHUR_FLOPS(flops); - Theta_br.from_block_row(Theta_ch); - Phi_br.from_block_row(Phi_ch); - Theta = DistM_t(grid(), Theta_ch.rows(), Theta_ch.cols()); - Phi = DistM_t(grid(), Phi_ch.rows(), Phi_ch.cols()); - copy(Theta.rows(), Theta.cols(), Theta_ch, 0, 0, Theta, 0, 0, grid()->ctxt_all()); - copy(Phi.rows(), Phi.cols(), Phi_ch, 0, 0, Phi, 0, 0, grid()->ctxt_all()); - - Vhat = DistM_t(grid(), Phi.cols(), Theta.cols()); - copy(Vhat.rows(), Vhat.cols(), ch0->ULV_mpi_.Vhat(), 0, 0, Vhat, 0, 0, grid()->ctxt_all()); - } /** * Apply Schur complement the direct way: diff --git a/src/HSS/HSSMatrixMPI.cpp b/src/HSS/HSSMatrixMPI.cpp index 6df01069..d40c8e8b 100644 --- a/src/HSS/HSSMatrixMPI.cpp +++ b/src/HSS/HSSMatrixMPI.cpp @@ -170,6 +170,64 @@ namespace strumpack { setup_ranges(roff, coff); } + template void HSSMatrixMPI::Schur_update + (DistM_t& Theta, DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const { + if (this->leaf()) return; + auto ch0 = child(0); + auto ch1 = child(1); + DistM_t DU(grid(), ch0->U_rows(), ch0->U_rank()); + if (auto ch0mpi = + dynamic_cast*>(child(0))) { + DistM_t chDU; + if (ch0mpi->active()) { + chDU = ch0->ULV_mpi_.D_.solve(ch0mpi->U_.dense(), ch0->ULV_mpi_.piv_); + STRUMPACK_SCHUR_FLOPS + (!ch0->ULV_mpi_.D_.is_master() ? 0 : + blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0mpi->U_.cols())); + } + copy(ch0->U_rows(), ch0->U_rank(), chDU, 0, 0, DU, 0, 0, grid()->ctxt_all()); + } else { + auto ch0seq = dynamic_cast*>(child(0)); + DenseM_t chDU; + if (ch0seq->active()) { + chDU = ch0->ULV_mpi_.D_.gather().solve + (ch0seq->U_.dense(), ch0->ULV_mpi_.piv_, ch0seq->openmp_task_depth_); + STRUMPACK_SCHUR_FLOPS + (!ch0->ULV_mpi_.D_.is_master() ? 0 : + blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0seq->U_.cols())); + } + copy(ch0->U_rows(), ch0->U_rank(), chDU, 0/*rank ch0*/, DU, 0, 0, grid()->ctxt_all()); + } + DUB01 = DistM_t(grid(), ch0->U_rows(), ch1->V_rank()); + gemm(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.), DUB01); + STRUMPACK_SCHUR_FLOPS + (gemm_flops(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.))); + + DistM_t _theta(ch1->grid(grid_local()), B10_.rows(), B10_.cols()); + copy(B10_.rows(), B10_.cols(), B10_, 0, 0, _theta, 0, 0, grid()->ctxt_all()); + auto DUB01t = DUB01.transpose(); + DistM_t _phi(ch1->grid(grid_local()), DUB01t.rows(), DUB01t.cols()); + copy(DUB01t.rows(), DUB01t.cols(), DUB01t, 0, 0, _phi, 0, 0, grid()->ctxt_all()); + DUB01t.clear(); + + DistSubLeaf Theta_br(_theta.cols(), ch1, grid_local()), + Phi_br(_phi.cols(), ch1, grid_local()); + DistM_t Theta_ch(ch1->grid(grid_local()), ch1->rows(), _theta.cols()); + DistM_t Phi_ch(ch1->grid(grid_local()), ch1->cols(), _phi.cols()); + long long int flops = 0; + ch1->apply_UV_big(Theta_br, _theta, Phi_br, _phi, flops); + STRUMPACK_SCHUR_FLOPS(flops); + Theta_br.from_block_row(Theta_ch); + Phi_br.from_block_row(Phi_ch); + Theta = DistM_t(grid(), Theta_ch.rows(), Theta_ch.cols()); + Phi = DistM_t(grid(), Phi_ch.rows(), Phi_ch.cols()); + copy(Theta.rows(), Theta.cols(), Theta_ch, 0, 0, Theta, 0, 0, grid()->ctxt_all()); + copy(Phi.rows(), Phi.cols(), Phi_ch, 0, 0, Phi, 0, 0, grid()->ctxt_all()); + + Vhat = DistM_t(grid(), Phi.cols(), Theta.cols()); + copy(Vhat.rows(), Vhat.cols(), ch0->ULV_mpi_.Vhat(), 0, 0, Vhat, 0, 0, grid()->ctxt_all()); + } + template void HSSMatrixMPI::setup_local_context() { if (!this->leaf()) {