Skip to content

Commit

Permalink
Remove dynamic cast from headers
Browse files Browse the repository at this point in the history
  • Loading branch information
AmperesAvengement committed Sep 27, 2024
1 parent b964011 commit 4055b25
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 77 deletions.
18 changes: 18 additions & 0 deletions src/BLR/BLRMatrixMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,24 @@ namespace strumpack {
return m;
}

template<typename scalar_t> DenseTile<scalar_t>&
BLRMatrixMPI<scalar_t>::ltile_dense(std::size_t i, std::size_t j) {
assert(i < rowblockslocal() && j < colblockslocal());
assert(dynamic_cast<DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get()));
return *static_cast<DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get());
}

template<typename scalar_t> const DenseTile<scalar_t>&
BLRMatrixMPI<scalar_t>::ltile_dense(std::size_t i, std::size_t j) const {
assert(i < rowblockslocal() && j < colblockslocal());
assert(dynamic_cast<const DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get()));
return *static_cast<const DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get());
}

template<typename scalar_t>
typename RealType<scalar_t>::value_type
BLRMatrixMPI<scalar_t>::normF() const {
Expand Down
16 changes: 2 additions & 14 deletions src/BLR/BLRMatrixMPI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,20 +272,8 @@ namespace strumpack {
return *blocks_[i+j*rowblockslocal()].get();
}

DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j) {
assert(i < rowblockslocal() && j < colblockslocal());
assert(dynamic_cast<DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get()));
return *static_cast<DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get());
}
const DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j) const {
assert(i < rowblockslocal() && j < colblockslocal());
assert(dynamic_cast<const DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get()));
return *static_cast<const DenseTile<scalar_t>*>
(blocks_[i+j*rowblockslocal()].get());
}
DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j);
const DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j) const;

std::unique_ptr<BLRTile<scalar_t>>&
block(std::size_t i, std::size_t j) {
Expand Down
10 changes: 10 additions & 0 deletions src/HSS/HSSMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ namespace strumpack {
(new HSSMatrix<scalar_t>(*this));
}

template<typename scalar_t> const HSSMatrix<scalar_t>*
HSSMatrix<scalar_t>::child(int c) const {
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
}

template<typename scalar_t> HSSMatrix<scalar_t>*
HSSMatrix<scalar_t>::child(int c) {
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
}

template<typename scalar_t> void
HSSMatrix<scalar_t>::delete_trailing_block() {
B01_.clear();
Expand Down
8 changes: 2 additions & 6 deletions src/HSS/HSSMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,15 @@ namespace strumpack {
* matrix. The value of c should be 0 or 1, and this HSS matrix
* should not be a leaf!
*/
const HSSMatrix<scalar_t>* child(int c) const {
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
}
const HSSMatrix<scalar_t>* child(int c) const;

/**
* Return a raw (non-owning) pointer to child c of this HSS
* matrix. A child of an HSS matrix is itself an HSS matrix. The
* value of c should be 0 or 1, and this HSS matrix should not
* be a leaf!
*/
HSSMatrix<scalar_t>* child(int c) {
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
}
HSSMatrix<scalar_t>* child(int c);

/**
* Initialize this HSS matrix as the compressed HSS
Expand Down
57 changes: 0 additions & 57 deletions src/HSS/HSSMatrixMPI.Schur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,63 +39,6 @@ namespace strumpack {
* Phi = (D0^{-1} * U0 * B01 * V1big^C)^C
* = V1big * (D0^{-1} * U0 * B01)^C
*/
template<typename scalar_t> void HSSMatrixMPI<scalar_t>::Schur_update
(DistM_t& Theta, DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const {
if (this->leaf()) return;
auto ch0 = child(0);
auto ch1 = child(1);
DistM_t DU(grid(), ch0->U_rows(), ch0->U_rank());
if (auto ch0mpi =
dynamic_cast<const HSSMatrixMPI<scalar_t>*>(child(0))) {
DistM_t chDU;
if (ch0mpi->active()) {
chDU = ch0->ULV_mpi_.D_.solve(ch0mpi->U_.dense(), ch0->ULV_mpi_.piv_);
STRUMPACK_SCHUR_FLOPS
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0mpi->U_.cols()));
}
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0, 0, DU, 0, 0, grid()->ctxt_all());
} else {
auto ch0seq = dynamic_cast<const HSSMatrix<scalar_t>*>(child(0));
DenseM_t chDU;
if (ch0seq->active()) {
chDU = ch0->ULV_mpi_.D_.gather().solve
(ch0seq->U_.dense(), ch0->ULV_mpi_.piv_, ch0seq->openmp_task_depth_);
STRUMPACK_SCHUR_FLOPS
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0seq->U_.cols()));
}
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0/*rank ch0*/, DU, 0, 0, grid()->ctxt_all());
}
DUB01 = DistM_t(grid(), ch0->U_rows(), ch1->V_rank());
gemm(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.), DUB01);
STRUMPACK_SCHUR_FLOPS
(gemm_flops(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.)));

DistM_t _theta(ch1->grid(grid_local()), B10_.rows(), B10_.cols());
copy(B10_.rows(), B10_.cols(), B10_, 0, 0, _theta, 0, 0, grid()->ctxt_all());
auto DUB01t = DUB01.transpose();
DistM_t _phi(ch1->grid(grid_local()), DUB01t.rows(), DUB01t.cols());
copy(DUB01t.rows(), DUB01t.cols(), DUB01t, 0, 0, _phi, 0, 0, grid()->ctxt_all());
DUB01t.clear();

DistSubLeaf<scalar_t> Theta_br(_theta.cols(), ch1, grid_local()),
Phi_br(_phi.cols(), ch1, grid_local());
DistM_t Theta_ch(ch1->grid(grid_local()), ch1->rows(), _theta.cols());
DistM_t Phi_ch(ch1->grid(grid_local()), ch1->cols(), _phi.cols());
long long int flops = 0;
ch1->apply_UV_big(Theta_br, _theta, Phi_br, _phi, flops);
STRUMPACK_SCHUR_FLOPS(flops);
Theta_br.from_block_row(Theta_ch);
Phi_br.from_block_row(Phi_ch);
Theta = DistM_t(grid(), Theta_ch.rows(), Theta_ch.cols());
Phi = DistM_t(grid(), Phi_ch.rows(), Phi_ch.cols());
copy(Theta.rows(), Theta.cols(), Theta_ch, 0, 0, Theta, 0, 0, grid()->ctxt_all());
copy(Phi.rows(), Phi.cols(), Phi_ch, 0, 0, Phi, 0, 0, grid()->ctxt_all());

Vhat = DistM_t(grid(), Phi.cols(), Theta.cols());
copy(Vhat.rows(), Vhat.cols(), ch0->ULV_mpi_.Vhat(), 0, 0, Vhat, 0, 0, grid()->ctxt_all());
}

/**
* Apply Schur complement the direct way:
Expand Down
58 changes: 58 additions & 0 deletions src/HSS/HSSMatrixMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,64 @@ namespace strumpack {
setup_ranges(roff, coff);
}

template<typename scalar_t> void HSSMatrixMPI<scalar_t>::Schur_update
(DistM_t& Theta, DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const {
if (this->leaf()) return;
auto ch0 = child(0);
auto ch1 = child(1);
DistM_t DU(grid(), ch0->U_rows(), ch0->U_rank());
if (auto ch0mpi =
dynamic_cast<const HSSMatrixMPI<scalar_t>*>(child(0))) {
DistM_t chDU;
if (ch0mpi->active()) {
chDU = ch0->ULV_mpi_.D_.solve(ch0mpi->U_.dense(), ch0->ULV_mpi_.piv_);
STRUMPACK_SCHUR_FLOPS
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0mpi->U_.cols()));
}
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0, 0, DU, 0, 0, grid()->ctxt_all());
} else {
auto ch0seq = dynamic_cast<const HSSMatrix<scalar_t>*>(child(0));
DenseM_t chDU;
if (ch0seq->active()) {
chDU = ch0->ULV_mpi_.D_.gather().solve
(ch0seq->U_.dense(), ch0->ULV_mpi_.piv_, ch0seq->openmp_task_depth_);
STRUMPACK_SCHUR_FLOPS
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0seq->U_.cols()));
}
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0/*rank ch0*/, DU, 0, 0, grid()->ctxt_all());
}
DUB01 = DistM_t(grid(), ch0->U_rows(), ch1->V_rank());
gemm(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.), DUB01);
STRUMPACK_SCHUR_FLOPS
(gemm_flops(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.)));

DistM_t _theta(ch1->grid(grid_local()), B10_.rows(), B10_.cols());
copy(B10_.rows(), B10_.cols(), B10_, 0, 0, _theta, 0, 0, grid()->ctxt_all());
auto DUB01t = DUB01.transpose();
DistM_t _phi(ch1->grid(grid_local()), DUB01t.rows(), DUB01t.cols());
copy(DUB01t.rows(), DUB01t.cols(), DUB01t, 0, 0, _phi, 0, 0, grid()->ctxt_all());
DUB01t.clear();

DistSubLeaf<scalar_t> Theta_br(_theta.cols(), ch1, grid_local()),
Phi_br(_phi.cols(), ch1, grid_local());
DistM_t Theta_ch(ch1->grid(grid_local()), ch1->rows(), _theta.cols());
DistM_t Phi_ch(ch1->grid(grid_local()), ch1->cols(), _phi.cols());
long long int flops = 0;
ch1->apply_UV_big(Theta_br, _theta, Phi_br, _phi, flops);
STRUMPACK_SCHUR_FLOPS(flops);
Theta_br.from_block_row(Theta_ch);
Phi_br.from_block_row(Phi_ch);
Theta = DistM_t(grid(), Theta_ch.rows(), Theta_ch.cols());
Phi = DistM_t(grid(), Phi_ch.rows(), Phi_ch.cols());
copy(Theta.rows(), Theta.cols(), Theta_ch, 0, 0, Theta, 0, 0, grid()->ctxt_all());
copy(Phi.rows(), Phi.cols(), Phi_ch, 0, 0, Phi, 0, 0, grid()->ctxt_all());

Vhat = DistM_t(grid(), Phi.cols(), Theta.cols());
copy(Vhat.rows(), Vhat.cols(), ch0->ULV_mpi_.Vhat(), 0, 0, Vhat, 0, 0, grid()->ctxt_all());
}

template<typename scalar_t> void
HSSMatrixMPI<scalar_t>::setup_local_context() {
if (!this->leaf()) {
Expand Down

0 comments on commit 4055b25

Please sign in to comment.