Skip to content

Commit

Permalink
Enable Build with FFTW in One Precision
Browse files Browse the repository at this point in the history
FFTW can be built in multiple precisions, but we will
only use one at a time with `amrex::Real`. Pick the one
we really use, do not require more than that.

This will simplify specifications and requirements in package
management.
  • Loading branch information
ax3l committed Jan 6, 2025
1 parent 393674b commit b4fb4bc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
39 changes: 24 additions & 15 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ namespace amrex::FFT
namespace
{
bool s_initialized = false;
std::map<Key, PlanD> s_plans_d;
#ifdef AMREX_USE_FLOAT
std::map<Key, PlanF> s_plans_f;
#else
std::map<Key, PlanD> s_plans_d;
#endif
}

void Initialize ()
Expand Down Expand Up @@ -43,42 +46,48 @@ void Finalize ()

void Clear ()
{
for (auto& [k, p] : s_plans_d) {
Plan<double>::destroy_vendor_plan(p);
}

#ifdef AMREX_USE_FLOAT
for (auto& [k, p] : s_plans_f) {
Plan<float>::destroy_vendor_plan(p);
}
#else
for (auto& [k, p] : s_plans_d) {
Plan<double>::destroy_vendor_plan(p);
}
#endif
}

PlanD* get_vendor_plan_d (Key const& key)
#ifdef AMREX_USE_FLOAT
PlanF* get_vendor_plan_f (Key const& key)
{
if (auto found = s_plans_d.find(key); found != s_plans_d.end()) {
if (auto found = s_plans_f.find(key); found != s_plans_f.end()) {
return &(found->second);
} else {
return nullptr;
}
}

PlanF* get_vendor_plan_f (Key const& key)
#else
PlanD* get_vendor_plan_d (Key const& key)
{
if (auto found = s_plans_f.find(key); found != s_plans_f.end()) {
if (auto found = s_plans_d.find(key); found != s_plans_d.end()) {
return &(found->second);
} else {
return nullptr;
}
}
#endif

void add_vendor_plan_d (Key const& key, PlanD plan)
{
s_plans_d[key] = plan;
}

#ifdef AMREX_USE_FLOAT
void add_vendor_plan_f (Key const& key, PlanF plan)
{
s_plans_f[key] = plan;
}
#else
void add_vendor_plan_d (Key const& key, PlanD plan)
{
s_plans_d[key] = plan;
}
#endif

}

Expand Down
28 changes: 19 additions & 9 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -1132,14 +1132,16 @@ struct Plan
};

using Key = std::tuple<IntVectND<3>,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

PlanD* get_vendor_plan_d (Key const& key);
#ifdef AMREX_USE_FLOAT
using PlanF = typename Plan<float>::VendorPlan;
PlanF* get_vendor_plan_f (Key const& key);

void add_vendor_plan_d (Key const& key, PlanD plan);
void add_vendor_plan_f (Key const& key, PlanF plan);
#else
using PlanD = typename Plan<double>::VendorPlan;
PlanD* get_vendor_plan_d (Key const& key);
void add_vendor_plan_d (Key const& key, PlanD plan);
#endif

template <typename T>
template <Direction D, int M>
Expand All @@ -1160,11 +1162,15 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
Key key = {fft_size.template expand<3>(), D, kind};
if (cache) {
VendorPlan* cached_plan = nullptr;
if constexpr (std::is_same_v<float,T>) {
#ifdef AMREX_USE_FLOAT
if constexpr (std::is_same_v<float, T>) {
cached_plan = get_vendor_plan_f(key);
} else {
}
#else
if constexpr (std::is_same_v<double, T>) {
cached_plan = get_vendor_plan_d(key);
}
#endif
if (cached_plan) {
plan = *cached_plan;
return;
Expand Down Expand Up @@ -1288,11 +1294,15 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool

#if defined(AMREX_USE_GPU)
if (cache) {
if constexpr (std::is_same_v<float,T>) {
#ifdef AMREX_USE_FLOAT
if constexpr (std::is_same_v<float, T>) {
add_vendor_plan_f(key, plan);
} else {
}
#else
if constexpr (std::is_same_v<double, T>) {
add_vendor_plan_d(key, plan);
}
#endif
}
#endif
}
Expand Down

0 comments on commit b4fb4bc

Please sign in to comment.