Skip to content

Commit

Permalink
solve_bicgstab: use fewer MFs (#3635)
Browse files Browse the repository at this point in the history
## Summary

This PR cuts the number of MFs used in `solve_bicgstab`, saving on
memory and LocalCopy operations. In particular, the MFs `ph` and `sh`
are removed.

## Additional background

This is a follow up to avoid-use-of-s and other PRs to improve
`solve_bicgstab`. My own testing has shown that this PR gives the same
results as before, but regression testing should be done to verify this
in all cases.
  • Loading branch information
eebasso authored Nov 20, 2023
1 parent 175b99d commit d75c04b
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

const int ncomp = sol.nComp();

MF ph = Lp.make(amrlev, mglev, sol.nGrowVect());
MF sh = Lp.make(amrlev, mglev, sol.nGrowVect());
ph.setVal(RT(0.0));
sh.setVal(RT(0.0));
MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
MF r = Lp.make(amrlev, mglev, sol.nGrowVect());
p.setVal(RT(0.0)); // Make sure all entries are initialized to avoid errors
r.setVal(RT(0.0));

MF sorig = Lp.make(amrlev, mglev, nghost);
MF p = Lp.make(amrlev, mglev, nghost);
MF r = Lp.make(amrlev, mglev, nghost);
MF rh = Lp.make(amrlev, mglev, nghost);
MF v = Lp.make(amrlev, mglev, nghost);
MF t = Lp.make(amrlev, mglev, nghost);
Expand Down Expand Up @@ -151,8 +149,7 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
}
ph.LocalCopy(p,0,0,ncomp,nghost);
Lp.apply(amrlev, mglev, v, ph, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.apply(amrlev, mglev, v, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, v);

RT rhTv = dotxy(rh,v);
Expand All @@ -164,9 +161,10 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 2; break;
}
MF::Saxpy(sol, alpha, ph, 0, 0, ncomp, nghost); // sol += alpha * ph
MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
MF::Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v

rnorm = norm_inf(r);
rnorm = norm_inf(r);

if ( verbose > 2 && ParallelDescriptor::IOProcessor() )
Expand All @@ -179,8 +177,7 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }

sh.LocalCopy(r,0,0,ncomp,nghost);
Lp.apply(amrlev, mglev, t, sh, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.apply(amrlev, mglev, t, r, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, t);
//
// This is a little funky. I want to elide one of the reductions
Expand All @@ -201,8 +198,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 3; break;
}
MF::Saxpy(sol, omega, sh, 0, 0, ncomp, nghost); // sol += omega * sh
MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
MF::Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t

rnorm = norm_inf(r);

Expand Down

0 comments on commit d75c04b

Please sign in to comment.