From ebf3d64feb11c89952fc215fa32c020171a1d336 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Thu, 23 Jan 2025 20:51:29 -0800 Subject: [PATCH] more fixes --- Src/FFT/AMReX_FFT_Poisson.H | 4 +- Src/FFT/AMReX_FFT_R2C.H | 102 ++++++++------- Src/FFT/AMReX_FFT_R2X.H | 119 ++++++++++-------- Tests/LinearSolvers/ABecLap_SP/CMakeLists.txt | 4 + 4 files changed, 126 insertions(+), 103 deletions(-) diff --git a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H index 815453f6860..b1748ad2ac6 100644 --- a/Src/FFT/AMReX_FFT_Poisson.H +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -252,11 +252,11 @@ void Poisson::solve (MF& soln, MF const& rhs) IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1)); if (m_r2x) { - m_r2x->template forwardThenBackward_doit<0>(rhs, soln, f, ng, m_geom.periodicity()); + m_r2x->template forwardThenBackward_doit_0(rhs, soln, f, ng, m_geom.periodicity()); detail::fill_physbc(soln, m_geom, m_bc); } else { m_r2c->forward(rhs); - m_r2c->template post_forward_doit<0>(f); + m_r2c->template post_forward_doit_0(f); m_r2c->backward_doit(soln, ng, m_geom.periodicity()); } } diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index 0d36b3c7493..67b3ca84c75 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -83,7 +83,7 @@ public: { BL_PROFILE("FFT::R2C::forwardbackward"); this->forward(inmf); - this->post_forward_doit<0>(post_forward); + this->post_forward_doit_0(post_forward); this->backward(outmf); } @@ -165,8 +165,11 @@ public: [[nodiscard]] std::pair getSpectralDataLayout () const; // This is a private function, but it's public for cuda. 
- template - void post_forward_doit (F const& post_forward); + template + void post_forward_doit_0 (F const& post_forward); + + template + void post_forward_doit_1 (F const& post_forward); private: @@ -665,63 +668,68 @@ R2C::make_c2c_plans (cMF& inout) } template -template -void R2C::post_forward_doit (F const& post_forward) +template +void R2C::post_forward_doit_0 (F const& post_forward) { if (m_info.batch_mode) { amrex::Abort("xxxxx todo: post_forward"); #if (AMREX_SPACEDIM > 1) } else if (m_r2c_sub) { -#if defined(AMREX_USE_CUDA) && defined(_WIN32) - if (Depth == 0) { -#else - if constexpr (Depth == 0) { -#endif - // We need to pass the originally ordered indices to post_forward. + // We need to pass the originally ordered indices to post_forward. #if (AMREX_SPACEDIM == 2) - // The original domain is (1,ny). The sub domain is (ny,1). - m_r2c_sub->template post_forward_doit<(Depth+1)> + // The original domain is (1,ny). The sub domain is (ny,1). + m_r2c_sub->template post_forward_doit_1 + ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) + { + post_forward(0, i, 0, sp); + }); +#else + if (m_real_domain.length(0) == 1 && m_real_domain.length(1) == 1) { + // Original domain: (1, 1, nz). Sub domain: (nz, 1, 1) + m_r2c_sub->template post_forward_doit_1 + ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) + { + post_forward(0, 0, i, sp); + }); + } else if (m_real_domain.length(0) == 1 && m_real_domain.length(2) == 1) { + // Original domain: (1, ny, 1). Sub domain: (ny, 1, 1) + m_r2c_sub->template post_forward_doit_1 ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) { post_forward(0, i, 0, sp); }); -#else - if (m_real_domain.length(0) == 1 && m_real_domain.length(1) == 1) { - // Original domain: (1, 1, nz). 
Sub domain: (nz, 1, 1) - m_r2c_sub->template post_forward_doit<(Depth+1)> - ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) - { - post_forward(0, 0, i, sp); - }); - } else if (m_real_domain.length(0) == 1 && m_real_domain.length(2) == 1) { - // Original domain: (1, ny, 1). Sub domain: (ny, 1, 1) - m_r2c_sub->template post_forward_doit<(Depth+1)> - ([=] AMREX_GPU_DEVICE (int i, int, int, auto& sp) - { - post_forward(0, i, 0, sp); - }); - } else if (m_real_domain.length(0) == 1) { - // Original domain: (1, ny, nz). Sub domain: (ny, nz, 1) - m_r2c_sub->template post_forward_doit<(Depth+1)> - ([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp) - { - post_forward(0, i, j, sp); - }); - } else if (m_real_domain.length(1) == 1) { - // Original domain: (nx, 1, nz). Sub domain: (nx, nz, 1) - m_r2c_sub->template post_forward_doit<(Depth+1)> - ([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp) - { - post_forward(i, 0, j, sp); - }); - } else { - amrex::Abort("R2c::post_forward_doit: how did this happen?"); - } -#endif + } else if (m_real_domain.length(0) == 1) { + // Original domain: (1, ny, nz). Sub domain: (ny, nz, 1) + m_r2c_sub->template post_forward_doit_1 + ([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp) + { + post_forward(0, i, j, sp); + }); + } else if (m_real_domain.length(1) == 1) { + // Original domain: (nx, 1, nz). Sub domain: (nx, nz, 1) + m_r2c_sub->template post_forward_doit_1 + ([=] AMREX_GPU_DEVICE (int i, int j, int, auto& sp) + { + post_forward(i, 0, j, sp); + }); } else { - amrex::Abort("R2C::post_forward_doit: How did this happen?"); + amrex::Abort("R2c::post_forward_doit_0: how did this happen?"); } #endif +#endif + } else { + this->template post_forward_doit_1(post_forward); + } +} + +template +template +void R2C::post_forward_doit_1 (F const& post_forward) +{ + if (m_info.batch_mode) { + amrex::Abort("xxxxx todo: post_forward"); + } else if (m_r2c_sub) { + amrex::Abort("R2C::post_forward_doit_1: How did this happen?"); } else { if ( ! 
m_cz.empty()) { auto* spectral_fab = detail::get_fab(m_cz); diff --git a/Src/FFT/AMReX_FFT_R2X.H b/Src/FFT/AMReX_FFT_R2X.H index 19326323a74..6e383a52144 100644 --- a/Src/FFT/AMReX_FFT_R2X.H +++ b/Src/FFT/AMReX_FFT_R2X.H @@ -52,10 +52,14 @@ public: void post_forward_doit (FAB* fab, F const& f); // private function made public for cuda - template - void forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward, - IntVect const& ngout = IntVect(0), - Periodicity const& period = Periodicity::NonPeriodic()); + template + void forwardThenBackward_doit_0 (MF const& inmf, MF& outmf, F const& post_forward, + IntVect const& ngout = IntVect(0), + Periodicity const& period = Periodicity::NonPeriodic()); + template + void forwardThenBackward_doit_1 (MF const& inmf, MF& outmf, F const& post_forward, + IntVect const& ngout = IntVect(0), + Periodicity const& period = Periodicity::NonPeriodic()); private: @@ -552,67 +556,74 @@ template template void R2X::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward) { - forwardThenBackward_doit<0>(inmf, outmf, post_forward); + forwardThenBackward_doit_0(inmf, outmf, post_forward); } template -template -void R2X::forwardThenBackward_doit (MF const& inmf, MF& outmf, - F const& post_forward, - IntVect const& ngout, - Periodicity const& period) +template +void R2X::forwardThenBackward_doit_0 (MF const& inmf, MF& outmf, + F const& post_forward, + IntVect const& ngout, + Periodicity const& period) { - BL_PROFILE("FFT::R2X::forwardbackward"); + BL_PROFILE("FFT::R2X::forwardbackward_0"); if (m_r2x_sub) { -#if defined(AMREX_USE_CUDA) && defined(_WIN32) - if (Depth == 0) -#else - if constexpr (Depth == 0) -#endif - { - bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); - MF inmf_sub, inmf_tmp; - if (inmf_safe) { - inmf_sub = m_sub_helper.make_alias_mf(inmf); - } else { - inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); - inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); - inmf_sub = 
m_sub_helper.make_alias_mf(inmf_tmp); - } - - bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); - MF outmf_sub, outmf_tmp; - if (outmf_safe) { - outmf_sub = m_sub_helper.make_alias_mf(outmf); - } else { - IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); - outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); - outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); - } + bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); + MF inmf_sub, inmf_tmp; + if (inmf_safe) { + inmf_sub = m_sub_helper.make_alias_mf(inmf); + } else { + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); + inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + } - IntVect const& subngout = m_sub_helper.make_iv(ngout); - Periodicity const& subperiod = m_sub_helper.make_periodicity(period); - GpuArray const& order = m_sub_helper.xyz_order(); - m_r2x_sub->template forwardThenBackward_doit<(Depth+1)> - (inmf_sub, outmf_sub, - [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& sp) - { - GpuArray idx{i,j,k}; - post_forward(idx[order[0]], idx[order[1]], idx[order[2]], sp); - }, - subngout, subperiod); - - if (!outmf_safe) { - outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); - } + bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); + MF outmf_sub, outmf_tmp; + if (outmf_safe) { + outmf_sub = m_sub_helper.make_alias_mf(outmf); + } else { + IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); + outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); } - else - { - amrex::Abort("R2X::forwardThenBackward_doit: How did this happen?"); + + IntVect const& subngout = m_sub_helper.make_iv(ngout); + Periodicity const& subperiod = m_sub_helper.make_periodicity(period); + GpuArray const& order = m_sub_helper.xyz_order(); + m_r2x_sub->template forwardThenBackward_doit_1 + (inmf_sub, outmf_sub, + [=] 
AMREX_GPU_DEVICE (int i, int j, int k, auto& sp) + { + GpuArray idx{i,j,k}; + post_forward(idx[order[0]], idx[order[1]], idx[order[2]], sp); + }, + subngout, subperiod); + + if (!outmf_safe) { + outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); } } else + { + this->template forwardThenBackward_doit_1(inmf, outmf, post_forward, ngout, period); + } +} + +template +template +void R2X::forwardThenBackward_doit_1 (MF const& inmf, MF& outmf, + F const& post_forward, + IntVect const& ngout, + Periodicity const& period) +{ + BL_PROFILE("FFT::R2X::forwardbackward_1"); + + if (m_r2x_sub) { + amrex::Abort("R2X::forwardThenBackward_doit_1: How did this happen?"); + } + else { this->forward(inmf); diff --git a/Tests/LinearSolvers/ABecLap_SP/CMakeLists.txt b/Tests/LinearSolvers/ABecLap_SP/CMakeLists.txt index 2de763f5f6a..5da80399c63 100644 --- a/Tests/LinearSolvers/ABecLap_SP/CMakeLists.txt +++ b/Tests/LinearSolvers/ABecLap_SP/CMakeLists.txt @@ -1,3 +1,7 @@ +if (WIN32 AND (AMReX_GPU_BACKEND STREQUAL "CUDA")) + return() +endif() + foreach(D IN LISTS AMReX_SPACEDIM) if (D EQUAL 1) continue()