Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Jan 24, 2025
1 parent 9842684 commit ebf3d64
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 103 deletions.
4 changes: 2 additions & 2 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,11 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1));

if (m_r2x) {
m_r2x->template forwardThenBackward_doit<0>(rhs, soln, f, ng, m_geom.periodicity());
m_r2x->template forwardThenBackward_doit_0(rhs, soln, f, ng, m_geom.periodicity());
detail::fill_physbc(soln, m_geom, m_bc);
} else {
m_r2c->forward(rhs);
m_r2c->template post_forward_doit<0>(f);
m_r2c->template post_forward_doit_0(f);
m_r2c->backward_doit(soln, ng, m_geom.periodicity());
}
}
Expand Down
102 changes: 55 additions & 47 deletions Src/FFT/AMReX_FFT_R2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public:
{
BL_PROFILE("FFT::R2C::forwardbackward");
this->forward(inmf);
this->post_forward_doit<0>(post_forward);
this->post_forward_doit_0(post_forward);
this->backward(outmf);
}

Expand Down Expand Up @@ -165,8 +165,11 @@ public:
[[nodiscard]] std::pair<BoxArray,DistributionMapping> getSpectralDataLayout () const;

// This is a private function, but it's public for cuda.
template <int Depth, typename F>
void post_forward_doit (F const& post_forward);
template <typename F>
void post_forward_doit_0 (F const& post_forward);

template <typename F>
void post_forward_doit_1 (F const& post_forward);

private:

Expand Down Expand Up @@ -665,63 +668,68 @@ R2C<T,D,S>::make_c2c_plans (cMF& inout)
}

template <typename T, Direction D, DomainStrategy S>
template <int Depth, typename F>
void R2C<T,D,S>::post_forward_doit (F const& post_forward)
template <typename F>
void R2C<T,D,S>::post_forward_doit_0 (F const& post_forward)
{
if (m_info.batch_mode) {
amrex::Abort("xxxxx todo: post_forward");
#if (AMREX_SPACEDIM > 1)
} else if (m_r2c_sub) {
#if defined(AMREX_USE_CUDA) && defined(_WIN32)
if (Depth == 0) {
#else
if constexpr (Depth == 0) {
#endif
// We need to pass the originally ordered indices to post_forward.
// We need to pass the originally ordered indices to post_forward.
#if (AMREX_SPACEDIM == 2)
// The original domain is (1,ny). The sub domain is (ny,1).
m_r2c_sub->template post_forward_doit<(Depth+1)>
// The original domain is (1,ny). The sub domain is (ny,1).
m_r2c_sub->template post_forward_doit_1
([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp)
{
post_forward(0, i, 0, sp);
});
#else
if (m_real_domain.length(0) == 1 && m_real_domain.length(1) == 1) {
// Original domain: (1, 1, nz). Sub domain: (nz, 1, 1)
m_r2c_sub->template post_forward_doit_1
([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp)
{
post_forward(0, 0, i, sp);
});
} else if (m_real_domain.length(0) == 1 && m_real_domain.length(2) == 1) {
// Original domain: (1, ny, 1). Sub domain: (ny, 1, 1)
m_r2c_sub->template post_forward_doit_1
([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp)
{
post_forward(0, i, 0, sp);
});
#else
if (m_real_domain.length(0) == 1 && m_real_domain.length(1) == 1) {
// Original domain: (1, 1, nz). Sub domain: (nz, 1, 1)
m_r2c_sub->template post_forward_doit<(Depth+1)>
([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp)
{
post_forward(0, 0, i, sp);
});
} else if (m_real_domain.length(0) == 1 && m_real_domain.length(2) == 1) {
// Original domain: (1, ny, 1). Sub domain: (ny, 1, 1)
m_r2c_sub->template post_forward_doit<(Depth+1)>
([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp)
{
post_forward(0, i, 0, sp);
});
} else if (m_real_domain.length(0) == 1) {
// Original domain: (1, ny, nz). Sub domain: (ny, nz, 1)
m_r2c_sub->template post_forward_doit<(Depth+1)>
([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp)
{
post_forward(0, i, j, sp);
});
} else if (m_real_domain.length(1) == 1) {
// Original domain: (nx, 1, nz). Sub domain: (nx, nz, 1)
m_r2c_sub->template post_forward_doit<(Depth+1)>
([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp)
{
post_forward(i, 0, j, sp);
});
} else {
amrex::Abort("R2c::post_forward_doit: how did this happen?");
}
#endif
} else if (m_real_domain.length(0) == 1) {
// Original domain: (1, ny, nz). Sub domain: (ny, nz, 1)
m_r2c_sub->template post_forward_doit_1
([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp)
{
post_forward(0, i, j, sp);
});
} else if (m_real_domain.length(1) == 1) {
// Original domain: (nx, 1, nz). Sub domain: (nx, nz, 1)
m_r2c_sub->template post_forward_doit_1
([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp)
{
post_forward(i, 0, j, sp);
});
} else {
amrex::Abort("R2C::post_forward_doit: How did this happen?");
amrex::Abort("R2c::post_forward_doit_0: how did this happen?");
}
#endif
#endif
} else {
this->template post_forward_doit_0(post_forward);
}
}

template <typename T, Direction D, DomainStrategy S>
template <typename F>
void R2C<T,D,S>::post_forward_doit_1 (F const& post_forward)
{
if (m_info.batch_mode) {
amrex::Abort("xxxxx todo: post_forward");
} else if (m_r2c_sub) {
amrex::Abort("R2C::post_forward_doit_1: How did this happen?");
} else {
if ( ! m_cz.empty()) {
auto* spectral_fab = detail::get_fab(m_cz);
Expand Down
119 changes: 65 additions & 54 deletions Src/FFT/AMReX_FFT_R2X.H
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ public:
void post_forward_doit (FAB* fab, F const& f);

// private function made public for cuda
template <int Depth, typename F>
void forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward,
IntVect const& ngout = IntVect(0),
Periodicity const& period = Periodicity::NonPeriodic());
template <typename F>
void forwardThenBackward_doit_0 (MF const& inmf, MF& outmf, F const& post_forward,
IntVect const& ngout = IntVect(0),
Periodicity const& period = Periodicity::NonPeriodic());
template <typename F>
void forwardThenBackward_doit_1 (MF const& inmf, MF& outmf, F const& post_forward,
IntVect const& ngout = IntVect(0),
Periodicity const& period = Periodicity::NonPeriodic());

private:

Expand Down Expand Up @@ -552,67 +556,74 @@ template <typename T>
template <typename F>
void R2X<T>::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward)
{
forwardThenBackward_doit<0>(inmf, outmf, post_forward);
forwardThenBackward_doit_0(inmf, outmf, post_forward);
}

template <typename T>
template <int Depth, typename F>
void R2X<T>::forwardThenBackward_doit (MF const& inmf, MF& outmf,
F const& post_forward,
IntVect const& ngout,
Periodicity const& period)
template <typename F>
void R2X<T>::forwardThenBackward_doit_0 (MF const& inmf, MF& outmf,
F const& post_forward,
IntVect const& ngout,
Periodicity const& period)
{
BL_PROFILE("FFT::R2X::forwardbackward");
BL_PROFILE("FFT::R2X::forwardbackward_0");

if (m_r2x_sub) {
#if defined(AMREX_USE_CUDA) && defined(_WIN32)
if (Depth == 0)
#else
if constexpr (Depth == 0)
#endif
{
bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect());
MF inmf_sub, inmf_tmp;
if (inmf_safe) {
inmf_sub = m_sub_helper.make_alias_mf(inmf);
} else {
inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0);
inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0));
inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp);
}

bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect());
MF outmf_sub, outmf_tmp;
if (outmf_safe) {
outmf_sub = m_sub_helper.make_alias_mf(outmf);
} else {
IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect());
outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp);
outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp);
}
bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect());
MF inmf_sub, inmf_tmp;
if (inmf_safe) {
inmf_sub = m_sub_helper.make_alias_mf(inmf);
} else {
inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0);
inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0));
inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp);
}

IntVect const& subngout = m_sub_helper.make_iv(ngout);
Periodicity const& subperiod = m_sub_helper.make_periodicity(period);
GpuArray<int,3> const& order = m_sub_helper.xyz_order();
m_r2x_sub->template forwardThenBackward_doit<(Depth+1)>
(inmf_sub, outmf_sub,
[=] AMREX_GPU_DEVICE (int i, int j, int k, auto& sp)
{
GpuArray<int,3> idx{i,j,k};
post_forward(idx[order[0]], idx[order[1]], idx[order[2]], sp);
},
subngout, subperiod);

if (!outmf_safe) {
outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect());
}
bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect());
MF outmf_sub, outmf_tmp;
if (outmf_safe) {
outmf_sub = m_sub_helper.make_alias_mf(outmf);
} else {
IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect());
outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp);
outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp);
}
else
{
amrex::Abort("R2X::forwardThenBackward_doit: How did this happen?");

IntVect const& subngout = m_sub_helper.make_iv(ngout);
Periodicity const& subperiod = m_sub_helper.make_periodicity(period);
GpuArray<int,3> const& order = m_sub_helper.xyz_order();
m_r2x_sub->template forwardThenBackward_doit_1
(inmf_sub, outmf_sub,
[=] AMREX_GPU_DEVICE (int i, int j, int k, auto& sp)
{
GpuArray<int,3> idx{i,j,k};
post_forward(idx[order[0]], idx[order[1]], idx[order[2]], sp);
},
subngout, subperiod);

if (!outmf_safe) {
outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect());
}
}
else
{
this->template forwardThenBackward_doit_1(inmf, outmf, post_forward, ngout, period);
}
}

template <typename T>
template <typename F>
void R2X<T>::forwardThenBackward_doit_1 (MF const& inmf, MF& outmf,
F const& post_forward,
IntVect const& ngout,
Periodicity const& period)
{
BL_PROFILE("FFT::R2X::forwardbackward_1");

if (m_r2x_sub) {
amrex::Abort("R2X::forwardThenBackward_doit_1: How did this happen?");
}
else
{
this->forward(inmf);

Expand Down
4 changes: 4 additions & 0 deletions Tests/LinearSolvers/ABecLap_SP/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
if (WIN32 AND (AMReX_GPU_BACKEND STREQUAL "CUDA"))
return()
endif()

foreach(D IN LISTS AMReX_SPACEDIM)
if (D EQUAL 1)
continue()
Expand Down

0 comments on commit ebf3d64

Please sign in to comment.