From fecb5e7fb39b80cc820d9265704cadea2b00de0f Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Fri, 13 Dec 2024 09:49:02 -0800 Subject: [PATCH] Make FFT safe for slabs Support FFT on domains that have one cell in some dimensions. It also supports Poisson solves on slab domains. However, for FFT::PoissonHybrid that treats the z-direction in a special way, the z-direction must have more than one cell. --- Src/Base/AMReX_Periodicity.H | 2 + Src/FFT/AMReX_FFT.cpp | 231 ++++++++++++++++++++ Src/FFT/AMReX_FFT_Helper.H | 79 +++++++ Src/FFT/AMReX_FFT_Poisson.H | 76 +++++-- Src/FFT/AMReX_FFT_R2C.H | 261 ++++++++++++++++++---- Src/FFT/AMReX_FFT_R2X.H | 409 +++++++++++++++++++++++++++-------- Tests/FFT/Poisson/main.cpp | 23 +- 7 files changed, 926 insertions(+), 155 deletions(-) diff --git a/Src/Base/AMReX_Periodicity.H b/Src/Base/AMReX_Periodicity.H index d16e175a4f..2618e37ef2 100644 --- a/Src/Base/AMReX_Periodicity.H +++ b/Src/Base/AMReX_Periodicity.H @@ -32,6 +32,8 @@ public: //! Cell-centered domain Box "infinitely" long in non-periodic directions. 
[[nodiscard]] Box Domain () const noexcept; + [[nodiscard]] IntVect const& intVect () const { return period; } + [[nodiscard]] std::vector shiftIntVect (IntVect const& nghost = IntVect(0)) const; static const Periodicity& NonPeriodic () noexcept; diff --git a/Src/FFT/AMReX_FFT.cpp b/Src/FFT/AMReX_FFT.cpp index 91ac1a7a92..20a4f1ad06 100644 --- a/Src/FFT/AMReX_FFT.cpp +++ b/Src/FFT/AMReX_FFT.cpp @@ -118,4 +118,235 @@ void hip_execute (rocfft_plan plan, void **in, void **out) } #endif +SubHelper::SubHelper (Box const& domain) +{ +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(domain); +#elif (AMREX_SPACEDIM == 2) + if (domain.length(0) == 1) { + m_case = case_1n; + } +#else + if (domain.length(0) == 1 && domain.length(1) == 1) { + m_case = case_11n; + } else if (domain.length(0) == 1 && domain.length(2) == 1) { + m_case = case_1n1; + } else if (domain.length(0) == 1) { + m_case = case_1nn; + } else if (domain.length(1) == 1) { + m_case = case_n1n; + } +#endif +} + +Box SubHelper::make_box (Box const& box) const +{ + return Box(make_iv(box.smallEnd()), make_iv(box.bigEnd()), box.ixType()); +} + +Periodicity SubHelper::make_periodicity (Periodicity const& period) const +{ + return Periodicity(make_iv(period.intVect())); +} + +bool SubHelper::ghost_safe (IntVect const& ng) const +{ +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(ng,this); + return true; +#elif (AMREX_SPACEDIM == 2) + if (m_case == case_1n) { + return (ng[0] == 0); + } else { + return true; + } +#else + if (m_case == case_11n) { + return (ng[0] == 0) && (ng[1] == 0); + } else if (m_case == case_1n1) { + return (ng[0] == 0); + } else if (m_case == case_1nn) { + return (ng[0] == 0); + } else if (m_case == case_n1n) { + return (ng[1] == 0); + } else { + return true; + } +#endif +} + +IntVect SubHelper::make_iv (IntVect const& iv) const +{ + return this->make_array(iv); +} + +IntVect SubHelper::make_safe_ghost (IntVect const& ng) const +{ +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(this); + return 
ng; +#elif (AMREX_SPACEDIM == 2) + if (m_case == case_1n) { + return IntVect{0,ng[1]}; + } else { + return ng; + } +#else + if (m_case == case_11n) { + return IntVect{0,0,ng[2]}; + } else if (m_case == case_1n1) { + return IntVect{0,ng[1],ng[2]}; + } else if (m_case == case_1nn) { + return IntVect{0,ng[1],ng[2]}; + } else if (m_case == case_n1n) { + return IntVect{ng[0],0,ng[2]}; + } else { + return ng; + } +#endif +} + +BoxArray SubHelper::inverse_boxarray (BoxArray const& ba) const +{ // sub domain order -> original domain order +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(this); + return ba; +#elif (AMREX_SPACEDIM == 2) + AMREX_ALWAYS_ASSERT(m_case == case_1n); + BoxList bl = ba.boxList(); + // sub domain order: y, x + for (auto& b : bl) { + auto const& lo = b.smallEnd(); + auto const& hi = b.bigEnd(); + b.setSmall(IntVect(lo[1],lo[0])); + b.setBig (IntVect(hi[1],hi[0])); + } + return BoxArray(std::move(bl)); +#else + BoxList bl = ba.boxList(); + if (m_case == case_11n) { + // sub domain order: z, x, y + for (auto& b : bl) { + auto const& lo = b.smallEnd(); + auto const& hi = b.bigEnd(); + b.setSmall(IntVect(lo[1],lo[2],lo[0])); + b.setBig (IntVect(hi[1],hi[2],hi[0])); + } + } else if (m_case == case_1n1) { + // sub domain order: y, x, z + for (auto& b : bl) { + auto const& lo = b.smallEnd(); + auto const& hi = b.bigEnd(); + b.setSmall(IntVect(lo[1],lo[0],lo[2])); + b.setBig (IntVect(hi[1],hi[0],hi[2])); + } + } else if (m_case == case_1nn) { + // sub domain order: y, z, x + for (auto& b : bl) { + auto const& lo = b.smallEnd(); + auto const& hi = b.bigEnd(); + b.setSmall(IntVect(lo[2],lo[0],lo[1])); + b.setBig (IntVect(hi[2],hi[0],hi[1])); + } + } else if (m_case == case_n1n) { + // sub domain order: x, z, y + for (auto& b : bl) { + auto const& lo = b.smallEnd(); + auto const& hi = b.bigEnd(); + b.setSmall(IntVect(lo[0],lo[2],lo[1])); + b.setBig (IntVect(hi[0],hi[2],hi[1])); + } + } else { + amrex::Abort("SubHelper::inverse_boxarray: how did this happen?"); 
+ } + return BoxArray(std::move(bl)); +#endif +} + +IntVect SubHelper::inverse_order (IntVect const& order) const +{ +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(this); + return order; +#elif (AMREX_SPACEDIM == 2) + amrex::ignore_unused(this); + return IntVect(order[1],order[0]); +#else + auto translate = [&] (int index) -> int + { + int r = index; + if (m_case == case_11n) { + // sub domain order: z, x, y + if (index == 0) { + r = 2; + } else if (index == 1) { + r = 0; + } else { + r = 1; + } + } else if (m_case == case_1n1) { + // sub domain order: y, x, z + if (index == 0) { + r = 1; + } else if (index == 1) { + r = 0; + } else { + r = 2; + } + } else if (m_case == case_1nn) { + // sub domain order: y, z, x + if (index == 0) { + r = 1; + } else if (index == 1) { + r = 2; + } else { + r = 0; + } + } else if (m_case == case_n1n) { + // sub domain order: x, z, y + if (index == 0) { + r = 0; + } else if (index == 1) { + r = 2; + } else { + r = 1; + } + } + return r; + }; + + IntVect iv; + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { + iv[idim] = translate(order[idim]); + } + return iv; +#endif +} + +GpuArray SubHelper::xyz_order () const +{ +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(this); + return GpuArray{0,1,2}; +#elif (AMREX_SPACEDIM == 2) + if (m_case == case_1n) { + return GpuArray{1,0,2}; + } else { + return GpuArray{0,1,2}; + } +#else + if (m_case == case_11n) { + return GpuArray{1,2,0}; + } else if (m_case == case_1n1) { + return GpuArray{1,0,2}; + } else if (m_case == case_1nn) { + return GpuArray{2,0,1}; + } else if (m_case == case_n1n) { + return GpuArray{0,2,1}; + } else { + return GpuArray{0,1,2}; + } +#endif +} + } diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index 369932f267..a0783dfac5 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -7,9 +7,11 @@ #include #include #include +#include #include #include #include +#include #if defined(AMREX_USE_CUDA) # include @@ -1447,6 +1449,83 @@ 
struct RotateBwd } }; +namespace detail +{ + struct SubHelper + { + explicit SubHelper (Box const& domain); + + [[nodiscard]] Box make_box (Box const& box) const; + + [[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const; + + [[nodiscard]] bool ghost_safe (IntVect const& ng) const; + + // This rearranges the order. + [[nodiscard]] IntVect make_iv (IntVect const& iv) const; + + // This keeps the order, but zeros out the values in the hidden dimension. + [[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const; + + [[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const; + + [[nodiscard]] IntVect inverse_order (IntVect const& order) const; + + template + [[nodiscard]] T make_array (T const& a) const + { +#if (AMREX_SPACEDIM == 1) + amrex::ignore_unused(this); + return a; +#elif (AMREX_SPACEDIM == 2) + if (m_case == case_1n) { + return T{a[1],a[0]}; + } else { + return a; + } +#else + if (m_case == case_11n) { + return T{a[2],a[0],a[1]}; + } else if (m_case == case_1n1) { + return T{a[1],a[0],a[2]}; + } else if (m_case == case_1nn) { + return T{a[1],a[2],a[0]}; + } else if (m_case == case_n1n) { + return T{a[0],a[2],a[1]}; + } else { + return a; + } +#endif + } + + [[nodiscard]] GpuArray xyz_order () const; + + template + FA make_alias_mf (FA const& mf) + { + BoxList bl = mf.boxArray().boxList(); + for (auto& b : bl) { + b = make_box(b); + } + auto const& ng = make_iv(mf.nGrowVect()); + FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false)); + using FAB = typename FA::fab_type; + for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) { + submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr())); + } + return submf; + } + +#if (AMREX_SPACEDIM == 2) + enum Case { case_1n, case_other }; + int m_case = case_other; +#elif (AMREX_SPACEDIM == 3) + enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other }; + int m_case = case_other; +#endif + }; +} + } #endif diff --git
a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H index 49087016e1..815453f686 100644 --- a/Src/FFT/AMReX_FFT_Poisson.H +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -112,13 +112,19 @@ public: Array,AMREX_SPACEDIM> const& bc) : m_geom(geom), m_bc(bc) { +#if (AMREX_SPACEDIM < 3) + amrex::Abort("FFT::PoissonHybrid: 1D & 2D todo"); + return; +#endif bool periodic_xy = true; for (int idim = 0; idim < 2; ++idim) { - periodic_xy = periodic_xy && (bc[idim].first == Boundary::periodic); - AMREX_ALWAYS_ASSERT((bc[idim].first == Boundary::periodic && - bc[idim].second == Boundary::periodic) || - (bc[idim].first != Boundary::periodic && - bc[idim].second != Boundary::periodic)); + if (m_geom.Domain().length(idim) > 1) { + periodic_xy = periodic_xy && (bc[idim].first == Boundary::periodic); + AMREX_ALWAYS_ASSERT((bc[idim].first == Boundary::periodic && + bc[idim].second == Boundary::periodic) || + (bc[idim].first != Boundary::periodic && + bc[idim].second != Boundary::periodic)); + } } Info info{}; info.setBatchMode(true); @@ -145,6 +151,7 @@ public: AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1)); #else amrex::Abort("FFT::PoissonHybrid: 1D & 2D todo"); + return; #endif build_spmf(); } @@ -204,6 +211,11 @@ void Poisson::solve (MF& soln, MF const& rhs) {AMREX_D_DECL(T(2)/T(m_geom.CellSize(0)*m_geom.CellSize(0)), T(2)/T(m_geom.CellSize(1)*m_geom.CellSize(1)), T(2)/T(m_geom.CellSize(2)*m_geom.CellSize(2)))}; + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { + if (m_geom.Domain().length(idim) == 1) { + dxfac[idim] = 0; + } + } auto scale = (m_r2x) ? 
m_r2x->scalingFactor() : m_r2c->scalingFactor(); GpuArray offset{AMREX_D_DECL(T(0),T(0),T(0))}; @@ -240,11 +252,11 @@ void Poisson::solve (MF& soln, MF const& rhs) IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1)); if (m_r2x) { - m_r2x->forwardThenBackward_doit(rhs, soln, f, ng, m_geom.periodicity()); + m_r2x->template forwardThenBackward_doit<0>(rhs, soln, f, ng, m_geom.periodicity()); detail::fill_physbc(soln, m_geom, m_bc); } else { m_r2c->forward(rhs); - m_r2c->post_forward_doit(f); + m_r2c->template post_forward_doit<0>(f); m_r2c->backward_doit(soln, ng, m_geom.periodicity()); } } @@ -347,14 +359,24 @@ PoissonHybrid::getSpectralDataLayout () const template void PoissonHybrid::build_spmf () { +#if (AMREX_SPACEDIM == 3) + AMREX_ALWAYS_ASSERT(m_geom.Domain().length(2) > 1 && + (m_geom.Domain().length(0) > 1 || + m_geom.Domain().length(1) > 1)); + if (m_r2c) { Box cdomain = m_geom.Domain(); - cdomain.setBig(0,cdomain.length(0)/2); + if (cdomain.length(0) > 1) { + cdomain.setBig(0,cdomain.length(0)/2); + } else { + cdomain.setBig(1,cdomain.length(1)/2); + } auto cba = amrex::decompose(cdomain, ParallelContext::NProcsSub(), {AMREX_D_DECL(true,true,false)}); DistributionMapping dm = detail::make_iota_distromap(cba.size()); m_spmf_c.define(cba, dm, 1, 0); - } else { + } else if (m_geom.Domain().length(0) > 1 && + m_geom.Domain().length(1) > 1) { if (m_r2x->m_cy.empty()) { // spectral data is real auto sba = amrex::decompose(m_geom.Domain(),ParallelContext::NProcsSub(), {AMREX_D_DECL(true,true,false)}); @@ -372,7 +394,16 @@ void PoissonHybrid::build_spmf () DistributionMapping dm = detail::make_iota_distromap(cba.size()); m_spmf_c.define(cba, dm, 1, 0); } + } else { + // spectral data is real + auto sba = amrex::decompose(m_geom.Domain(),ParallelContext::NProcsSub(), + {AMREX_D_DECL(true,true,false)}); + DistributionMapping dm = detail::make_iota_distromap(sba.size()); + m_spmf_r.define(sba, dm, 1, 0); } +#else + amrex::ignore_unused(this); +#endif } 
template @@ -465,19 +496,24 @@ void PoissonHybrid::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric) auto dyfac = T(2)/T(m_geom.CellSize(1)*m_geom.CellSize(1)); auto scale = (m_r2x) ? m_r2x->scalingFactor() : m_r2c->scalingFactor(); + if (m_geom.Domain().length(0) == 1) { dxfac = 0; } + if (m_geom.Domain().length(1) == 1) { dyfac = 0; } + GpuArray offset{T(0),T(0)}; for (int idim = 0; idim < AMREX_SPACEDIM-1; ++idim) { - if (m_bc[idim].first == Boundary::odd && - m_bc[idim].second == Boundary::odd) - { - offset[idim] = T(1); - } - else if ((m_bc[idim].first == Boundary::odd && - m_bc[idim].second == Boundary::even) || - (m_bc[idim].first == Boundary::even && - m_bc[idim].second == Boundary::odd)) - { - offset[idim] = T(0.5); + if (m_geom.Domain().length(idim) > 1) { + if (m_bc[idim].first == Boundary::odd && + m_bc[idim].second == Boundary::odd) + { + offset[idim] = T(1); + } + else if ((m_bc[idim].first == Boundary::odd && + m_bc[idim].second == Boundary::even) || + (m_bc[idim].first == Boundary::even && + m_bc[idim].second == Boundary::odd)) + { + offset[idim] = T(0.5); + } } } diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index e1adea156a..75ced825dd 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -30,7 +30,7 @@ template class PoissonHybrid; * https://amrex-codes.github.io/amrex/docs_html/FFT_Chapter.html. */ template + FFT::DomainStrategy S = FFT::DomainStrategy::slab> // Don't change the default. Otherwise OpenBCSolver might break. class R2C { @@ -83,7 +83,7 @@ public: { BL_PROFILE("FFT::R2C::forwardbackward"); this->forward(inmf); - this->post_forward_doit(post_forward); + this->post_forward_doit<0>(post_forward); this->backward(outmf); } @@ -165,7 +165,7 @@ public: [[nodiscard]] std::pair getSpectralDataLayout () const; // This is a private function, but it's public for cuda. 
- template + template void post_forward_doit (F const& post_forward); private: @@ -221,6 +221,9 @@ private: Box m_spectral_domain_y; Box m_spectral_domain_z; + std::unique_ptr> m_r2c_sub; + detail::SubHelper m_sub_helper; + Info m_info; bool m_do_alld_fft = false; @@ -247,24 +250,42 @@ R2C::R2C (Box const& domain, Info const& info) domain.ixType()), #endif #endif + m_sub_helper(domain), m_info(info) { BL_PROFILE("FFT::R2C"); static_assert(std::is_same_v || std::is_same_v); - AMREX_ALWAYS_ASSERT(m_real_domain.length(0) > 1); -#if (AMREX_SPACEDIM == 3) - AMREX_ALWAYS_ASSERT(m_real_domain.length(1) > 1 || m_real_domain.length(2) == 1); + + AMREX_ALWAYS_ASSERT(m_real_domain.numPts() > 1); +#if (AMREX_SPACEDIM == 2) + AMREX_ALWAYS_ASSERT(!m_info.batch_mode); #else - AMREX_ALWAYS_ASSERT(! m_info.batch_mode); + if (m_info.batch_mode) { + AMREX_ALWAYS_ASSERT((int(domain.length(0) > 1) + + int(domain.length(1) > 1) + + int(domain.length(2) > 1)) >= 2); + } #endif + { + Box subbox = m_sub_helper.make_box(m_real_domain); + if (subbox.size() != m_real_domain.size()) { + m_r2c_sub = std::make_unique>(subbox, info); + return; + } + } + int myproc = ParallelContext::MyProcSub(); int nprocs = std::min(ParallelContext::NProcsSub(), m_info.nprocs); #if (AMREX_SPACEDIM == 3) if (S == DomainStrategy::slab && (m_real_domain.length(1) > 1)) { - m_slab_decomp = true; + if (m_info.batch_mode && m_real_domain.length(2) == 1) { + m_slab_decomp = false; + } else { + m_slab_decomp = true; + } } #endif @@ -292,8 +313,14 @@ R2C::R2C (Box const& domain, Info const& info) // #if (AMREX_SPACEDIM >= 2) +#if (AMREX_SPACEDIM == 2) + bool batch_on_y = false; +#else + bool batch_on_y = m_info.batch_mode && (m_real_domain.length(2) == 1); +#endif DistributionMapping cdmy; - if ((m_real_domain.length(1) > 1) && !m_slab_decomp) { + if ((m_real_domain.length(1) > 1) && !m_slab_decomp && !batch_on_y) + { auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {AMREX_D_DECL(false,true,true)}, true); if 
(cbay.size() == dmx.size()) { @@ -441,6 +468,8 @@ R2C::~R2C () template void R2C::prepare_openbc () { + if (m_r2c_sub) { amrex::Abort("R2C: OpenBC not supported with reduced dimensions"); } + #if (AMREX_SPACEDIM == 3) if (m_do_alld_fft) { return; } @@ -497,6 +526,17 @@ void R2C::forward (MF const& inmf) { BL_PROFILE("FFT::R2C::forward(in)"); + if (m_r2c_sub) { + if (m_sub_helper.ghost_safe(inmf.nGrowVect())) { + m_r2c_sub->forward(m_sub_helper.make_alias_mf(inmf)); + } else { + MF tmp(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + m_r2c_sub->forward(m_sub_helper.make_alias_mf(tmp)); + } + return; + } + if (&m_rx != &inmf) { m_rx.ParallelCopy(inmf, 0, 0, 1); } @@ -552,9 +592,25 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout, { BL_PROFILE("FFT::R2C::backward(out)"); + if (m_r2c_sub) { + if (m_sub_helper.ghost_safe(outmf.nGrowVect())) { + MF submf = m_sub_helper.make_alias_mf(outmf); + IntVect const& subngout = m_sub_helper.make_iv(ngout); + Periodicity const& subperiod = m_sub_helper.make_periodicity(period); + m_r2c_sub->backward_doit(submf, subngout, subperiod); + } else { + MF tmp(outmf.boxArray(), outmf.DistributionMap(), 1, + m_sub_helper.make_safe_ghost(outmf.nGrowVect())); + this->backward_doit(tmp, ngout, period); + outmf.LocalCopy(tmp, 0, 0, 1, tmp.nGrowVect()); + } + return; + } + if (m_do_alld_fft) { m_fft_bwd_x.template compute_r2c(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), + amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); return; } @@ -576,7 +632,8 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout, auto& fft_x = m_openbc_half ? 
m_fft_bwd_x_half : m_fft_bwd_x; fft_x.template compute_r2c(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), + amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); } template @@ -608,11 +665,59 @@ R2C::make_c2c_plans (cMF& inout) } template -template +template void R2C::post_forward_doit (F const& post_forward) { if (m_info.batch_mode) { amrex::Abort("xxxxx todo: post_forward"); +#if (AMREX_SPACEDIM > 1) + } else if (m_r2c_sub) { + if constexpr (Depth == 0) { + // We need to pass the originally ordered indices to post_forward. +#if (AMREX_SPACEDIM == 2) + // The original domain is (1,ny). The sub domain is (ny,1). + m_r2c_sub->template post_forward_doit<(Depth+1)> + ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) + { + post_forward(0, i, 0, sp); + }); +#else + if (m_real_domain.length(0) == 1 && m_real_domain.length(1) == 1) { + // Original domain: (1, 1, nz). Sub domain: (nz, 1, 1) + m_r2c_sub->template post_forward_doit<(Depth+1)> + ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) + { + post_forward(0, 0, i, sp); + }); + } else if (m_real_domain.length(0) == 1 && m_real_domain.length(2) == 1) { + // Original domain: (1, ny, 1). Sub domain: (ny, 1, 1) + m_r2c_sub->template post_forward_doit<(Depth+1)> + ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) + { + post_forward(0, i, 0, sp); + }); + } else if (m_real_domain.length(0) == 1) { + // Original domain: (1, ny, nz). Sub domain: (ny, nz, 1) + m_r2c_sub->template post_forward_doit<(Depth+1)> + ([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp) + { + post_forward(0, i, j, sp); + }); + } else if (m_real_domain.length(1) == 1) { + // Original domain: (nx, 1, nz). 
Sub domain: (nx, nz, 1) + m_r2c_sub->template post_forward_doit<(Depth+1)> + ([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp) + { + post_forward(i, 0, j, sp); + }); + } else { + amrex::Abort("R2c::post_forward_doit: how did this happen?"); + } +#endif + } else { + amrex::Abort("R2C::post_forward_doit: How did this happen?"); + } +#endif } else { if ( ! m_cz.empty()) { auto* spectral_fab = detail::get_fab(m_cz); @@ -653,8 +758,12 @@ T R2C::scalingFactor () const { #if (AMREX_SPACEDIM == 3) if (m_info.batch_mode) { - return T(1)/T(Long(m_real_domain.length(0)) * - Long(m_real_domain.length(1))); + if (m_real_domain.length(2) > 1) { + return T(1)/T(Long(m_real_domain.length(0)) * + Long(m_real_domain.length(1))); + } else { + return T(1)/T(m_real_domain.length(0)); + } } else #endif { @@ -668,6 +777,12 @@ template ::cMF *, IntVect> R2C::getSpectralData () { +#if (AMREX_SPACEDIM > 1) + if (m_r2c_sub) { + auto [cmf, order] = m_r2c_sub->getSpectralData(); + return std::make_pair(cmf, m_sub_helper.inverse_order(order)); + } else +#endif if (!m_cz.empty()) { return std::make_pair(&m_cz, IntVect{AMREX_D_DECL(2,0,1)}); } else if (!m_cy.empty()) { @@ -684,18 +799,48 @@ void R2C::forward (MF const& inmf, cMF& outmf) { BL_PROFILE("FFT::R2C::forward(inout)"); - forward(inmf); - if (!m_cz.empty()) { // m_cz's order (z,x,y) -> (x,y,z) - RotateBwd dtos{}; - MultiBlockCommMetaData cmd - (outmf, m_spectral_domain_x, m_cz, IntVect(0), dtos); - ParallelCopy(outmf, m_cz, cmd, 0, 0, 1, dtos); - } else if (!m_cy.empty()) { // m_cy's order (y,x,z) -> (x,y,z) - MultiBlockCommMetaData cmd - (outmf, m_spectral_domain_x, m_cy, IntVect(0), m_dtos_y2x); - ParallelCopy(outmf, m_cy, cmd, 0, 0, 1, m_dtos_y2x); - } else { - outmf.ParallelCopy(m_cx, 0, 0, 1); + if (m_r2c_sub) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + MF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); + } else { + inmf_tmp.define(inmf.boxArray(), 
inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } + + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + cMF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); + } else { + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, 0); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + } + + m_r2c_sub->forward(inmf_sub, outmf_sub); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, IntVect(0)); + } + } + else + { + forward(inmf); + if (!m_cz.empty()) { // m_cz's order (z,x,y) -> (x,y,z) + RotateBwd dtos{}; + MultiBlockCommMetaData cmd + (outmf, m_spectral_domain_x, m_cz, IntVect(0), dtos); + ParallelCopy(outmf, m_cz, cmd, 0, 0, 1, dtos); + } else if (!m_cy.empty()) { // m_cy's order (y,x,z) -> (x,y,z) + MultiBlockCommMetaData cmd + (outmf, m_spectral_domain_x, m_cy, IntVect(0), m_dtos_y2x); + ParallelCopy(outmf, m_cy, cmd, 0, 0, 1, m_dtos_y2x); + } else { + outmf.ParallelCopy(m_cx, 0, 0, 1); + } } } @@ -713,25 +858,65 @@ void R2C::backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout { BL_PROFILE("FFT::R2C::backward(inout)"); - if (!m_cz.empty()) { // (x,y,z) -> m_cz's order (z,x,y) - RotateFwd dtos{}; - MultiBlockCommMetaData cmd - (m_cz, m_spectral_domain_z, inmf, IntVect(0), dtos); - ParallelCopy(m_cz, inmf, cmd, 0, 0, 1, dtos); - } else if (!m_cy.empty()) { // (x,y,z) -> m_cy's ordering (y,x,z) - MultiBlockCommMetaData cmd - (m_cy, m_spectral_domain_y, inmf, IntVect(0), m_dtos_x2y); - ParallelCopy(m_cy, inmf, cmd, 0, 0, 1, m_dtos_x2y); - } else { - m_cx.ParallelCopy(inmf, 0, 0, 1); + if (m_r2c_sub) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + cMF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); + } else { + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = 
m_sub_helper.make_alias_mf(inmf_tmp); + } + + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + MF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); + } else { + IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + } + + IntVect const& subngout = m_sub_helper.make_iv(ngout); + Periodicity const& subperiod = m_sub_helper.make_periodicity(period); + m_r2c_sub->backward_doit(inmf_sub, outmf_sub, subngout, subperiod); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); + } + } + else + { + if (!m_cz.empty()) { // (x,y,z) -> m_cz's order (z,x,y) + RotateFwd dtos{}; + MultiBlockCommMetaData cmd + (m_cz, m_spectral_domain_z, inmf, IntVect(0), dtos); + ParallelCopy(m_cz, inmf, cmd, 0, 0, 1, dtos); + } else if (!m_cy.empty()) { // (x,y,z) -> m_cy's ordering (y,x,z) + MultiBlockCommMetaData cmd + (m_cy, m_spectral_domain_y, inmf, IntVect(0), m_dtos_x2y); + ParallelCopy(m_cy, inmf, cmd, 0, 0, 1, m_dtos_x2y); + } else { + m_cx.ParallelCopy(inmf, 0, 0, 1); + } + backward_doit(outmf, ngout, period); } - backward_doit(outmf, ngout, period); } template std::pair R2C::getSpectralDataLayout () const { +#if (AMREX_SPACEDIM > 1) + if (m_r2c_sub) { + auto const& [ba, dm] = m_r2c_sub->getSpectralDataLayout(); + return std::make_pair(m_sub_helper.inverse_boxarray(ba), dm); + } +#endif + #if (AMREX_SPACEDIM == 3) if (!m_cz.empty()) { BoxList bl = m_cz.boxArray().boxList(); diff --git a/Src/FFT/AMReX_FFT_R2X.H b/Src/FFT/AMReX_FFT_R2X.H index 19cd8fca12..d86bf5c29b 100644 --- a/Src/FFT/AMReX_FFT_R2X.H +++ b/Src/FFT/AMReX_FFT_R2X.H @@ -51,6 +51,12 @@ public: template void post_forward_doit (FAB* fab, F const& f); + // private function made public for cuda + template + void forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward, + IntVect 
const& ngout = IntVect(0), + Periodicity const& period = Periodicity::NonPeriodic()); + private: void forward (MF const& inmf, MF& outmf); @@ -62,11 +68,6 @@ private: Periodicity const& period); void backward (); - template - void forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward, - IntVect const& ngout = IntVect(0), - Periodicity const& period = Periodicity::NonPeriodic()); - Box m_dom_0; Array,AMREX_SPACEDIM> m_bc; @@ -109,6 +110,9 @@ private: Box m_dom_cy; Box m_dom_cz; + std::unique_ptr> m_r2x_sub; + detail::SubHelper m_sub_helper; + Info m_info; }; @@ -118,27 +122,40 @@ R2X::R2X (Box const& domain, Info const& info) : m_dom_0(domain), m_bc(bc), + m_sub_helper(domain), m_info(info) { BL_PROFILE("FFT::R2X"); static_assert(std::is_same_v || std::is_same_v); - AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0 && - domain.length(0) > 1 && - domain.cellCentered()); -#if (AMREX_SPACEDIM == 3) - AMREX_ALWAYS_ASSERT(domain.length(1) > 1 || domain.length(2) == 1); + + AMREX_ALWAYS_ASSERT(m_dom_0.numPts() > 1); +#if (AMREX_SPACEDIM == 2) + AMREX_ALWAYS_ASSERT(!m_info.batch_mode); #else - AMREX_ALWAYS_ASSERT(! 
m_info.batch_mode); + if (m_info.batch_mode) { + AMREX_ALWAYS_ASSERT((int(domain.length(0) > 1) + + int(domain.length(1) > 1) + + int(domain.length(2) > 1)) >= 2); + } #endif + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { - AMREX_ALWAYS_ASSERT(domain.length(idim) > 1); if (bc[idim].first == Boundary::periodic || bc[idim].second == Boundary::periodic) { AMREX_ALWAYS_ASSERT(bc[idim].first == bc[idim].second); } } + { + Box subbox = m_sub_helper.make_box(m_dom_0); + if (subbox.size() != m_dom_0.size()) { + m_r2x_sub = std::make_unique> + (subbox, m_sub_helper.make_array(bc), info); + return; + } + } + int myproc = ParallelContext::MyProcSub(); int nprocs = std::min(ParallelContext::NProcsSub(), m_info.nprocs); @@ -166,7 +183,14 @@ R2X::R2X (Box const& domain, } // else: x-fft: r2r(m_rx) #if (AMREX_SPACEDIM >= 2) - if (domain.length(1) > 1) { + +#if (AMREX_SPACEDIM == 2) + bool batch_on_y = false; +#else + bool batch_on_y = m_info.batch_mode && (m_dom_0.length(2) == 1); +#endif + + if ((domain.length(1) > 1) && !batch_on_y) { if (! m_cx.empty()) { // copy(m_cx->m_cy) m_dom_cy = Box(IntVect(0), IntVect(AMREX_D_DECL(m_dom_cx.bigEnd(1), @@ -308,7 +332,7 @@ R2X::R2X (Box const& domain, // #if (AMREX_SPACEDIM >= 2) - if (domain.length(1) > 1) { + if (!m_cy.empty() || !m_ry.empty()) { if (! m_cx.empty()) { // copy(m_cx->m_cy) m_cmd_cx2cy = std::make_unique @@ -326,7 +350,7 @@ R2X::R2X (Box const& domain, #endif #if (AMREX_SPACEDIM == 3) - if (domain.length(2) > 1 && !m_info.batch_mode) { + if (!m_cz.empty() || !m_rz.empty()) { if (! m_cy.empty()) { // copy(m_cy, m_cz) m_cmd_cy2cz = std::make_unique @@ -512,6 +536,9 @@ T R2X::scalingFactor () const { Long r = 1; int ndims = m_info.batch_mode ? 
AMREX_SPACEDIM-1 : AMREX_SPACEDIM; +#if (AMREX_SPACEDIM == 3) + if (m_info.batch_mode && m_dom_0.length(2) == 1) { ndims = 1; }; +#endif for (int idim = 0; idim < ndims; ++idim) { r *= m_dom_0.length(idim); if (m_bc[idim].first != Boundary::periodic && (m_dom_0.length(idim) > 1)) { @@ -525,11 +552,11 @@ template template void R2X::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward) { - forwardThenBackward_doit(inmf, outmf, post_forward); + forwardThenBackward_doit<0>(inmf, outmf, post_forward); } template -template +template void R2X::forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward, IntVect const& ngout, @@ -537,47 +564,95 @@ void R2X::forwardThenBackward_doit (MF const& inmf, MF& outmf, { BL_PROFILE("FFT::R2X::forwardbackward"); - this->forward(inmf); + if (m_r2x_sub) { + if constexpr (Depth == 0) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + MF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); + } else { + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } - // post-forward + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + MF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); + } else { + IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + } - int actual_dim = AMREX_SPACEDIM; + IntVect const& subngout = m_sub_helper.make_iv(ngout); + Periodicity const& subperiod = m_sub_helper.make_periodicity(period); + GpuArray const& order = m_sub_helper.xyz_order(); + m_r2x_sub->template forwardThenBackward_doit<(Depth+1)> + (inmf_sub, outmf_sub, + [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& sp) + { + GpuArray idx{i,j,k}; + post_forward(idx[order[0]], 
idx[order[1]], idx[order[2]], sp); + }, + subngout, subperiod); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); + } + } + else + { + amrex::Abort("R2X::forwardThenBackward_doit: How did this happen?"); + } + } + else + { + this->forward(inmf); + + // post-forward + + int actual_dim = AMREX_SPACEDIM; #if (AMREX_SPACEDIM >= 2) - if (m_dom_0.length(1) == 1) { actual_dim = 1; } + if (m_dom_0.length(1) == 1) { actual_dim = 1; } #endif #if (AMREX_SPACEDIM == 3) - if ((m_dom_0.length(2) == 1) && (m_dom_0.length(1) > 1)) { actual_dim = 2; } + if ((m_dom_0.length(2) == 1) && (m_dom_0.length(1) > 1)) { actual_dim = 2; } #endif - if (actual_dim == 1) { - if (m_cx.empty()) { - post_forward_doit<0>(detail::get_fab(m_rx), post_forward); - } else { - post_forward_doit<0>(detail::get_fab(m_cx), post_forward); + if (actual_dim == 1) { + if (m_cx.empty()) { + post_forward_doit<0>(detail::get_fab(m_rx), post_forward); + } else { + post_forward_doit<0>(detail::get_fab(m_cx), post_forward); + } } - } #if (AMREX_SPACEDIM >= 2) - else if (actual_dim == 2) { - if (m_cy.empty()) { - post_forward_doit<1>(detail::get_fab(m_ry), post_forward); - } else { - post_forward_doit<1>(detail::get_fab(m_cy), post_forward); + else if (actual_dim == 2) { + if (m_cy.empty()) { + post_forward_doit<1>(detail::get_fab(m_ry), post_forward); + } else { + post_forward_doit<1>(detail::get_fab(m_cy), post_forward); + } } - } #endif #if (AMREX_SPACEDIM == 3) - else if (actual_dim == 3) { - if (m_cz.empty()) { - post_forward_doit<2>(detail::get_fab(m_rz), post_forward); - } else { - post_forward_doit<2>(detail::get_fab(m_cz), post_forward); + else if (actual_dim == 3) { + if (m_cz.empty()) { + post_forward_doit<2>(detail::get_fab(m_rz), post_forward); + } else { + post_forward_doit<2>(detail::get_fab(m_cz), post_forward); + } } - } #endif - this->backward(); + this->backward(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); + outmf.ParallelCopy(m_rx, 0, 0, 1, 
IntVect(0), + amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); + } } template @@ -585,6 +660,17 @@ void R2X::forward (MF const& inmf) { BL_PROFILE("FFT::R2X::forward"); + if (m_r2x_sub) { + if (m_sub_helper.ghost_safe(inmf.nGrowVect())) { + m_r2x_sub->forward(m_sub_helper.make_alias_mf(inmf)); + } else { + MF tmp(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + m_r2x_sub->forward(m_sub_helper.make_alias_mf(tmp)); + } + return; + } + m_rx.ParallelCopy(inmf, 0, 0, 1); if (m_bc[0].first == Boundary::periodic) { m_fft_fwd_x.template compute_r2c(); @@ -637,45 +723,109 @@ void R2X::forward (MF const& inmf) template void R2X::forward (MF const& inmf, MF& outmf) { - this->forward(inmf); + if (m_r2x_sub) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + MF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); + } else { + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } -#if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { - if (m_cy.empty()) { - ParallelCopy(outmf, m_dom_rx, m_ry, 0, 0, 1, IntVect(0), Swap01{}); + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + MF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); } else { - amrex::Abort("R2X::forward(MF,MF): How did this happen?"); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, 0); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); } - } else -#endif + + m_r2x_sub->forward(inmf_sub, outmf_sub); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, IntVect(0)); + } + } + else { - amrex::ignore_unused(outmf); - amrex::Abort("R2X::forward(MF,MF): TODO"); + this->forward(inmf); + +#if (AMREX_SPACEDIM == 3) + if (m_info.batch_mode) { + if (m_cy.empty() && !m_ry.empty()) { + ParallelCopy(outmf, m_dom_rx, m_ry, 0, 0, 1, IntVect(0), 
Swap01{}); + } else if (m_ry.empty() && m_cy.empty() && m_cx.empty()) { + outmf.ParallelCopy(m_rx, 0, 0, 1); + } else { + amrex::Abort("R2X::forward(MF,MF): How did this happen?"); + } + } else +#endif + { + amrex::ignore_unused(outmf); + amrex::Abort("R2X::forward(MF,MF): TODO"); + } } } template void R2X::forward (MF const& inmf, cMF& outmf) { - this->forward(inmf); + if (m_r2x_sub) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + MF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); + } else { + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } -#if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { - if (!m_cy.empty()) { - auto lo = m_dom_cy.smallEnd(); - auto hi = m_dom_cy.bigEnd(); - std::swap(lo[0],lo[1]); - std::swap(hi[0],hi[1]); - Box dom(lo,hi); - ParallelCopy(outmf, dom, m_cy, 0, 0, 1, IntVect(0), Swap01{}); + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + cMF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); } else { - amrex::Abort("R2X::forward(MF,cMF): How did this happen?"); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, 0); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); } - } else -#endif + + m_r2x_sub->forward(inmf_sub, outmf_sub); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, IntVect(0)); + } + } + else { - amrex::ignore_unused(outmf); - amrex::Abort("R2X::forward(MF,cMF): TODO"); + this->forward(inmf); + +#if (AMREX_SPACEDIM == 3) + if (m_info.batch_mode) { + if (!m_cy.empty()) { + auto lo = m_dom_cy.smallEnd(); + auto hi = m_dom_cy.bigEnd(); + std::swap(lo[0],lo[1]); + std::swap(hi[0],hi[1]); + Box dom(lo,hi); + ParallelCopy(outmf, dom, m_cy, 0, 0, 1, IntVect(0), Swap01{}); + } else if (m_ry.empty() && m_cy.empty() && !m_cx.empty()) { + outmf.ParallelCopy(m_cx, 0, 0, 1); + } else { 
+ amrex::Abort("R2X::forward(MF,cMF): How did this happen?"); + } + } else +#endif + { + amrex::ignore_unused(outmf); + amrex::Abort("R2X::forward(MF,cMF): TODO"); + } } } @@ -684,6 +834,8 @@ void R2X::backward () { BL_PROFILE("FFT::R2X::backward"); + AMREX_ALWAYS_ASSERT(m_r2x_sub == nullptr); + #if (AMREX_SPACEDIM == 3) if (m_bc[2].first != Boundary::periodic) { @@ -736,52 +888,127 @@ template void R2X::backward (MF const& inmf, MF& outmf, IntVect const& ngout, Periodicity const& period) { -#if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { - if (m_cy.empty()) { - ParallelCopy(m_ry, m_dom_ry, inmf, 0, 0, 1, IntVect(0), Swap01{}); + if (m_r2x_sub) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + MF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); } else { - amrex::Abort("R2X::backward(MF,MF): How did this happen?"); + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } + + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + MF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); + } else { + IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + } + + IntVect const& subngout = m_sub_helper.make_iv(ngout); + Periodicity const& subperiod = m_sub_helper.make_periodicity(period); + m_r2x_sub->backward(inmf_sub, outmf_sub, subngout, subperiod); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); } - } else -#endif - { - amrex::ignore_unused(inmf,outmf,ngout,period); - amrex::Abort("R2X::backward(MF,MF): TODO"); } + else + { +#if (AMREX_SPACEDIM == 3) + if (m_info.batch_mode) { + if (m_cy.empty() && !m_ry.empty()) { + ParallelCopy(m_ry, m_dom_ry, inmf, 0, 0, 1, IntVect(0), 
Swap01{}); + } else if (m_ry.empty() && m_cy.empty() && m_cx.empty()) { + m_rx.ParallelCopy(inmf, 0, 0, 1); + } else { + amrex::Abort("R2X::backward(MF,MF): How did this happen?"); + } + } else +#endif + { + amrex::ignore_unused(inmf,outmf,ngout,period); + amrex::Abort("R2X::backward(MF,MF): TODO"); + } - this->backward(); + this->backward(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), + amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); + } } template void R2X::backward (cMF const& inmf, MF& outmf, IntVect const& ngout, Periodicity const& period) { -#if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { - if (!m_cy.empty()) { - ParallelCopy(m_cy, m_dom_cy, inmf, 0, 0, 1, IntVect(0), Swap01{}); + if (m_r2x_sub) + { + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + cMF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); } else { - amrex::Abort("R2X::backward(cMF,MF): How did this happen?"); + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } + + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + MF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); + } else { + IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + } + + IntVect const& subngout = m_sub_helper.make_iv(ngout); + Periodicity const& subperiod = m_sub_helper.make_periodicity(period); + m_r2x_sub->backward(inmf_sub, outmf_sub, subngout, subperiod); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); } - } else -#endif - { - amrex::ignore_unused(inmf,outmf,ngout,period); - amrex::Abort("R2X::backward(cMF,MF): TODO"); } + else + { +#if 
(AMREX_SPACEDIM == 3) + if (m_info.batch_mode) { + if (!m_cy.empty()) { + ParallelCopy(m_cy, m_dom_cy, inmf, 0, 0, 1, IntVect(0), Swap01{}); + } else if (m_ry.empty() && m_cy.empty() && !m_cx.empty()) { + m_cx.ParallelCopy(inmf, 0, 0, 1); + } else { + amrex::Abort("R2X::backward(cMF,MF): How did this happen?"); + } + } else +#endif + { + amrex::ignore_unused(inmf,outmf,ngout,period); + amrex::Abort("R2X::backward(cMF,MF): TODO"); + } - this->backward(); + this->backward(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), + amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); + } } template template void R2X::post_forward_doit (FAB* fab, F const& f) { + if (m_info.batch_mode) { + amrex::Abort("xxxxx post_forward_doit: todo"); + } if (fab) { auto const& a = fab->array(); ParallelFor(fab->box(), diff --git a/Tests/FFT/Poisson/main.cpp b/Tests/FFT/Poisson/main.cpp index f344edbd38..fb7c869cd7 100644 --- a/Tests/FFT/Poisson/main.cpp +++ b/Tests/FFT/Poisson/main.cpp @@ -56,10 +56,13 @@ void make_rhs (MultiFab& rhs, Geometry const& geom, }); bool has_dirichlet = false; + auto domlen = geom.Domain().length(); for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { - has_dirichlet = has_dirichlet || - fft_bc[idim].first == FFT::Boundary::odd || - fft_bc[idim].second == FFT::Boundary::odd; + if (domlen[idim] > 1) { + has_dirichlet = has_dirichlet || + fft_bc[idim].first == FFT::Boundary::odd || + fft_bc[idim].second == FFT::Boundary::odd; + } } if (! has_dirichlet) { // Shift rhs so that its sum is zero. 
@@ -80,15 +83,23 @@ std::pair check_convergence {AMREX_D_DECL(1._rt/(dx[0]*dx[0]), 1._rt/(dx[1]*dx[1]), 1._rt/(dx[2]*dx[2]))}; + auto domlen = geom.Domain().length(); ParallelFor(res, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k) { auto const& phia = phi_ma[b]; - Real lap = (phia(i-1,j,k)-2._rt*phia(i,j,k)+phia(i+1,j,k)) * lapfac[0]; + Real lap = 0; + if (domlen[0] > 1) { + lap += (phia(i-1,j,k)-2._rt*phia(i,j,k)+phia(i+1,j,k)) * lapfac[0]; + } #if (AMREX_SPACEDIM >= 2) - lap += (phia(i,j-1,k)-2._rt*phia(i,j,k)+phia(i,j+1,k)) * lapfac[1]; + if (domlen[1] > 1) { + lap += (phia(i,j-1,k)-2._rt*phia(i,j,k)+phia(i,j+1,k)) * lapfac[1]; + } #endif #if (AMREX_SPACEDIM == 3) - lap += (phia(i,j,k-1)-2._rt*phia(i,j,k)+phia(i,j,k+1)) * lapfac[2]; + if (domlen[2] > 1) { + lap += (phia(i,j,k-1)-2._rt*phia(i,j,k)+phia(i,j,k+1)) * lapfac[2]; + } #endif res_ma[b](i,j,k) = rhs_ma[b](i,j,k) - lap; });