Skip to content

Commit

Permalink
Try again to fix COLWISE
Browse files Browse the repository at this point in the history
  • Loading branch information
pghysels committed Jul 19, 2024
1 parent e7f2c85 commit 76c45e6
Showing 1 changed file with 32 additions and 34 deletions.
66 changes: 32 additions & 34 deletions src/BLR/BLRMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1422,61 +1422,59 @@ namespace strumpack {
trsm(Side::L, UpLo::L, Trans::N, Diag::U,
scalar_t(1.), B11.tile(k, k), B11.tile(k, i));
#pragma omp taskloop
for (std::size_t lk=k+1; lk<rb+rb2; lk++)
if (lk < rb)
gemm(Trans::N, Trans::N, scalar_t(-1.),
B11.tile(lk, k), B11.tile(k, i), scalar_t(1.),
B11.tile_dense(lk, i).D());
else
gemm(Trans::N, Trans::N, scalar_t(-1.),
B21.tile(lk-rb, k), B11.tile(k, i), scalar_t(1.),
B21.tile_dense(lk-rb, i).D());
for (std::size_t j=k+1; j<rb; j++)
gemm(Trans::N, Trans::N, scalar_t(-1.),
B11.tile(j, k), B11.tile(k, i), scalar_t(1.),
B11.tile_dense(j, i).D());
}
auto tpiv = B11.tile(i, i).LU(opts.pivot_threshold());
std::copy(tpiv.begin(), tpiv.end(),
B11.piv_.begin()+B11.tileroff(i));
#pragma omp taskloop
for (std::size_t j=i+1; j<rb+rb2; j++)
if (j < rb) {
for (std::size_t j=i+1; j<rb; j++) {
trsm(Side::R, UpLo::U, Trans::N, Diag::N,
scalar_t(1.), B11.tile(i, i), B11.tile(j, i));
if (admissible(j, i))
B11.compress_tile(j, i, opts);
trsm(Side::R, UpLo::U, Trans::N, Diag::N,
scalar_t(1.), B11.tile(i, i), B11.tile(j, i));
} else {
B21.compress_tile(j-rb, i, opts);
trsm(Side::R, UpLo::U, Trans::N, Diag::N,
scalar_t(1.), B11.tile(i, i), B21.tile(j-rb, i));
}
}
#pragma omp taskloop
for (std::size_t j=0; j<rb2; j++) {
for (std::size_t k=0; k<i; k++)
gemm(Trans::N, Trans::N, scalar_t(-1.),
B21.tile(j, k), B11.tile(k, i), scalar_t(1.),
B21.tile_dense(j, i).D());
B21.compress_tile(j, i, opts);
trsm(Side::R, UpLo::U, Trans::N, Diag::N,
scalar_t(1.), B11.tile(i, i), B21.tile(j, i));
}
}
for (std::size_t i=0; i<rb2; i+=CP) { // F12 and F22
B12.fill_col(0., i, CP);
B22.fill_col(0., i, CP);
blockcol(i, false, CP);
for (std::size_t k=0; k<rb; k++) {
#pragma omp taskloop
for (std::size_t lk=k+1; lk<rb+rb2; lk++)
if (lk < rb)
gemm(Trans::N, Trans::N, scalar_t(-1.),
B11.tile(lk, k), B12.tile(k, i), scalar_t(1.),
B12.tile_dense(lk, i).D());
else
gemm(Trans::N, Trans::N, scalar_t(-1.),
B21.tile(lk-rb, k), B12.tile(k, i), scalar_t(1.),
B22.tile_dense(lk-rb, i).D());
}
#pragma omp taskloop
for (std::size_t k=0; k<rb; k++) {
B12.compress_tile(k, i, opts);
std::vector<int> tpiv
(B11.piv_.begin()+B11.tileroff(k),
B11.piv_.begin()+B11.tileroff(k+1));
B12.tile(k, i).laswp(tpiv, true);
trsm(Side::L, UpLo::L, Trans::N, Diag::U,
scalar_t(1.), B11.tile(k, k), B12.tile(k, i));
#pragma omp taskloop
for (std::size_t j=k+1; j<rb; j++)
gemm(Trans::N, Trans::N, scalar_t(-1.),
B11.tile(j, k), B12.tile(k, i), scalar_t(1.),
B12.tile_dense(j, i).D());
}
#pragma omp taskloop
for (std::size_t k=0; k<rb2; k++)
if (i != k)
B22.compress_tile(k, i, opts);
for (std::size_t j=0; j<rb2; j++) {
for (std::size_t k=0; k<rb; k++)
gemm(Trans::N, Trans::N, scalar_t(-1.),
B21.tile(j, k), B12.tile(k, i), scalar_t(1.),
B22.tile_dense(j, i).D());
if (j != i)
B22.compress_tile(j, i, opts);
}
}
}
for (std::size_t i=0; i<rb; i++)
Expand Down

0 comments on commit 76c45e6

Please sign in to comment.