diff --git a/Src/FFT/AMReX_FFT.cpp b/Src/FFT/AMReX_FFT.cpp index 20a4f1ad06..82f201b354 100644 --- a/Src/FFT/AMReX_FFT.cpp +++ b/Src/FFT/AMReX_FFT.cpp @@ -9,8 +9,11 @@ namespace amrex::FFT namespace { bool s_initialized = false; - std::map s_plans_d; +#ifdef AMREX_USE_FLOAT std::map s_plans_f; +#else + std::map s_plans_d; +#endif } void Initialize () @@ -43,42 +46,48 @@ void Finalize () void Clear () { - for (auto& [k, p] : s_plans_d) { - Plan::destroy_vendor_plan(p); - } - +#ifdef AMREX_USE_FLOAT for (auto& [k, p] : s_plans_f) { Plan::destroy_vendor_plan(p); } +#else + for (auto& [k, p] : s_plans_d) { + Plan::destroy_vendor_plan(p); + } +#endif } -PlanD* get_vendor_plan_d (Key const& key) +#ifdef AMREX_USE_FLOAT +PlanF* get_vendor_plan_f (Key const& key) { - if (auto found = s_plans_d.find(key); found != s_plans_d.end()) { + if (auto found = s_plans_f.find(key); found != s_plans_f.end()) { return &(found->second); } else { return nullptr; } } - -PlanF* get_vendor_plan_f (Key const& key) +#else +PlanD* get_vendor_plan_d (Key const& key) { - if (auto found = s_plans_f.find(key); found != s_plans_f.end()) { + if (auto found = s_plans_d.find(key); found != s_plans_d.end()) { return &(found->second); } else { return nullptr; } } +#endif -void add_vendor_plan_d (Key const& key, PlanD plan) -{ - s_plans_d[key] = plan; -} - +#ifdef AMREX_USE_FLOAT void add_vendor_plan_f (Key const& key, PlanF plan) { s_plans_f[key] = plan; } +#else +void add_vendor_plan_d (Key const& key, PlanD plan) +{ + s_plans_d[key] = plan; +} +#endif } diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index a0783dfac5..94451c68e1 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -1132,14 +1132,16 @@ struct Plan }; using Key = std::tuple,Direction,Kind>; -using PlanD = typename Plan::VendorPlan; -using PlanF = typename Plan::VendorPlan; -PlanD* get_vendor_plan_d (Key const& key); +#ifdef AMREX_USE_FLOAT +using PlanF = typename Plan::VendorPlan; PlanF* get_vendor_plan_f (Key const& key); - -void add_vendor_plan_d (Key const& key, PlanD plan); void add_vendor_plan_f (Key const& key, PlanF plan); +#else +using PlanD = typename Plan::VendorPlan; +PlanD* get_vendor_plan_d (Key const& key); +void add_vendor_plan_d (Key const& key, PlanD plan); +#endif template template @@ -1160,11 +1162,15 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool Key key = {fft_size.template expand<3>(), D, kind}; if (cache) { VendorPlan* cached_plan = nullptr; - if constexpr (std::is_same_v) { +#ifdef AMREX_USE_FLOAT + if constexpr (std::is_same_v) { cached_plan = get_vendor_plan_f(key); - } else { + } +#else + if constexpr (std::is_same_v) { cached_plan = get_vendor_plan_d(key); } +#endif if (cached_plan) { plan = *cached_plan; return; @@ -1288,11 +1294,15 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool #if defined(AMREX_USE_GPU) if (cache) { - if constexpr (std::is_same_v) { +#ifdef AMREX_USE_FLOAT + if constexpr (std::is_same_v) { add_vendor_plan_f(key, plan); - } else { + } +#else + if constexpr (std::is_same_v) { add_vendor_plan_d(key, plan); } +#endif } #endif }