Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FFT safe for slabs #4268

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Src/Base/AMReX_Periodicity.H
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public:
//! Cell-centered domain Box "infinitely" long in non-periodic directions.
[[nodiscard]] Box Domain () const noexcept;

[[nodiscard]] IntVect const& intVect () const { return period; }

[[nodiscard]] std::vector<IntVect> shiftIntVect (IntVect const& nghost = IntVect(0)) const;

static const Periodicity& NonPeriodic () noexcept;
Expand Down
231 changes: 231 additions & 0 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,235 @@ void hip_execute (rocfft_plan plan, void **in, void **out)
}
#endif

SubHelper::SubHelper (Box const& domain)
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(domain);
#elif (AMREX_SPACEDIM == 2)
if (domain.length(0) == 1) {
m_case = case_1n;
}
#else
if (domain.length(0) == 1 && domain.length(1) == 1) {
m_case = case_11n;
} else if (domain.length(0) == 1 && domain.length(2) == 1) {
m_case = case_1n1;
} else if (domain.length(0) == 1) {
m_case = case_1nn;
} else if (domain.length(1) == 1) {
m_case = case_n1n;
}
#endif
}

Box SubHelper::make_box (Box const& box) const
{
return Box(make_iv(box.smallEnd()), make_iv(box.bigEnd()), box.ixType());
}

Periodicity SubHelper::make_periodicity (Periodicity const& period) const
{
return Periodicity(make_iv(period.intVect()));
}

bool SubHelper::ghost_safe (IntVect const& ng) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(ng,this);
return true;
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return (ng[0] == 0);
} else {
return true;
}
#else
if (m_case == case_11n) {
return (ng[0] == 0) && (ng[1] == 0);
} else if (m_case == case_1n1) {
return (ng[0] == 0);
} else if (m_case == case_1nn) {
return (ng[0] == 0);
} else if (m_case == case_n1n) {
return (ng[1] == 0);
} else {
return true;
}
#endif
}

IntVect SubHelper::make_iv (IntVect const& iv) const
{
return this->make_array(iv);
}

IntVect SubHelper::make_safe_ghost (IntVect const& ng) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return ng;
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return IntVect{0,ng[1]};
} else {
return ng;
}
#else
if (m_case == case_11n) {
return IntVect{0,0,ng[2]};
} else if (m_case == case_1n1) {
return IntVect{0,ng[1],ng[2]};
} else if (m_case == case_1nn) {
return IntVect{0,ng[1],ng[2]};
} else if (m_case == case_n1n) {
return IntVect{ng[0],0,ng[2]};
} else {
return ng;
}
#endif
}

BoxArray SubHelper::inverse_boxarray (BoxArray const& ba) const
{ // sub domain order -> original domain order
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return ba;
#elif (AMREX_SPACEDIM == 2)
AMREX_ALWAYS_ASSERT(m_case == case_1n);
BoxList bl = ba.boxList();
// sub domain order: y, x
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[1],lo[0]));
b.setBig (IntVect(hi[1],hi[0]));
}
return BoxArray(std::move(bl));
#else
BoxList bl = ba.boxList();
if (m_case == case_11n) {
// sub domain order: z, x, y
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[1],lo[2],lo[0]));
b.setBig (IntVect(hi[1],hi[2],hi[0]));
}
} else if (m_case == case_1n1) {
// sub domain order: y, x, z
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[1],lo[0],lo[2]));
b.setBig (IntVect(hi[1],hi[0],hi[2]));
}
} else if (m_case == case_1nn) {
// sub domain order: y, z, x
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[2],lo[0],lo[1]));
b.setBig (IntVect(hi[2],hi[0],hi[1]));
}
} else if (m_case == case_n1n) {
// sub domain order: x, z, y
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[0],lo[2],lo[1]));
b.setBig (IntVect(hi[0],hi[2],hi[1]));
}
} else {
amrex::Abort("SubHelper::inverse_boxarray: how did this happen?");
}
return BoxArray(std::move(bl));
#endif
}

IntVect SubHelper::inverse_order (IntVect const& order) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return order;
#elif (AMREX_SPACEDIM == 2)
amrex::ignore_unused(this);
return IntVect(order[1],order[0]);
#else
auto translate = [&] (int index) -> int
{
int r = index;
if (m_case == case_11n) {
// sub domain order: z, x, y
if (index == 0) {
r = 2;
} else if (index == 1) {
r = 0;
} else {
r = 1;
}
} else if (m_case == case_1n1) {
// sub domain order: y, x, z
if (index == 0) {
r = 1;
} else if (index == 1) {
r = 0;
} else {
r = 2;
}
} else if (m_case == case_1nn) {
// sub domain order: y, z, x
if (index == 0) {
r = 1;
} else if (index == 1) {
r = 2;
} else {
r = 0;
}
} else if (m_case == case_n1n) {
// sub domain order: x, z, y
if (index == 0) {
r = 0;
} else if (index == 1) {
r = 2;
} else {
r = 1;
}
}
return r;
};

IntVect iv;
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
iv[idim] = translate(order[idim]);
}
return iv;
#endif
}

GpuArray<int,3> SubHelper::xyz_order () const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return GpuArray<int,3>{0,1,2};
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return GpuArray<int,3>{1,0,2};
} else {
return GpuArray<int,3>{0,1,2};
}
#else
if (m_case == case_11n) {
return GpuArray<int,3>{1,2,0};
} else if (m_case == case_1n1) {
return GpuArray<int,3>{1,0,2};
} else if (m_case == case_1nn) {
return GpuArray<int,3>{2,0,1};
} else if (m_case == case_n1n) {
return GpuArray<int,3>{0,2,1};
} else {
return GpuArray<int,3>{0,1,2};
}
#endif
}

}
79 changes: 79 additions & 0 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
#include <AMReX_DataAllocator.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_Enum.H>
#include <AMReX_FabArray.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>
#include <AMReX_Periodicity.H>

#if defined(AMREX_USE_CUDA)
# include <cufft.h>
Expand Down Expand Up @@ -1447,6 +1449,83 @@ struct RotateBwd
}
};

namespace detail
{
struct SubHelper
{
explicit SubHelper (Box const& domain);

[[nodiscard]] Box make_box (Box const& box) const;

[[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;

[[nodiscard]] bool ghost_safe (IntVect const& ng) const;

// This rearranges the order.
[[nodiscard]] IntVect make_iv (IntVect const& iv) const;

// This keeps the order, but zero out the values in the hidden dimension.
[[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;

[[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;

[[nodiscard]] IntVect inverse_order (IntVect const& order) const;

template <typename T>
[[nodiscard]] T make_array (T const& a) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return a;
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return T{a[1],a[0]};
} else {
return a;
}
#else
if (m_case == case_11n) {
return T{a[2],a[0],a[1]};
} else if (m_case == case_1n1) {
return T{a[1],a[0],a[2]};
} else if (m_case == case_1nn) {
return T{a[1],a[2],a[0]};
} else if (m_case == case_n1n) {
return T{a[0],a[2],a[1]};
} else {
return a;
}
#endif
}

[[nodiscard]] GpuArray<int,3> xyz_order () const;

template <typename FA>
FA make_alias_mf (FA const& mf)
{
BoxList bl = mf.boxArray().boxList();
for (auto& b : bl) {
b = make_box(b);
}
auto const& ng = make_iv(mf.nGrowVect());
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
using FAB = typename FA::fab_type;
for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
}
return submf;
}

#if (AMREX_SPACEDIM == 2)
enum Case { case_1n, case_other };
int m_case = case_other;
#elif (AMREX_SPACEDIM == 3)
enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
int m_case = case_other;
#endif
};
}

}

#endif
Loading
Loading