Skip to content

Commit

Permalink
Fix GPU Compile
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Nov 23, 2023
1 parent e05117b commit b448f6c
Showing 1 changed file with 40 additions and 30 deletions.
70 changes: 40 additions & 30 deletions src/particles/CollectLost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,42 @@

namespace impactx
{
struct CopyAndMarkNegative
{
static constexpr int s_index = 0; //!< index of runtime attribute in destination for position s where particle got lost
amrex::ParticleReal s_lost; //!< position s in meters where particle got lost

using SrcData = ImpactXParticleContainer::ParticleTileType::ConstParticleTileDataType;
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;

AMREX_GPU_HOST_DEVICE
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
dst.m_aos[dst_ip] = src.m_aos[src_ip];

for (int j = 0; j < SrcData::NAR; ++j)
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
for (int j = 0; j < src.m_num_runtime_real; ++j)
dst.m_runtime_rdata[j][dst_ip] = src.m_runtime_rdata[j][src_ip];

// unused: integer compile-time or runtime attributes
//for (int j = 0; j < SrcData::NAI; ++j)
// dst.m_idata[j][dst_ip] = src.m_idata[j][src_ip];
//for (int j = 0; j < src.m_num_runtime_int; ++j)
// dst.m_runtime_idata[j][dst_ip] = src.m_runtime_idata[j][src_ip];

// flip id to positive in destination
dst.id(dst_ip) = amrex::Math::abs(dst.id(dst_ip));

// remember the current s of the ref particle when lost
dst.m_runtime_rdata[s_index][dst_ip] = s_lost;
}
};

void collect_lost_particles (ImpactXParticleContainer& source)
{
BL_PROFILE("impactX::collect_lost_particles");

using SrcData = ImpactXParticleContainer::ParticleTileType::ConstParticleTileDataType;
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;

ImpactXParticleContainer& dest = *source.GetLostParticleContainer();

Expand Down Expand Up @@ -52,7 +82,8 @@ namespace impactx
if (np == 0) continue; // no particles in source tile

// we will copy particles that were marked as lost, with a negative id
auto predicate = [] AMREX_GPU_HOST_DEVICE (const SrcData& src, int ip, const amrex::RandomEngine& /*engine*/) noexcept
auto const predicate = [] AMREX_GPU_HOST_DEVICE (const SrcData& src, int ip)
/* NVCC 11.3.109 chokes in C++17 on this: noexcept */
{
return src.id(ip) < 0;
};
Expand All @@ -64,11 +95,11 @@ namespace impactx
amrex::ReduceOps<amrex::ReduceOpSum> reduce_op;
amrex::ReduceData<int> reduce_data(reduce_op);
{
const auto src_data = ptile_source.getConstParticleTileData();
auto const src_data = ptile_source.getConstParticleTileData();

const amrex::RandomEngine rng{}; // unused
reduce_op.eval(np, reduce_data, [=] AMREX_GPU_HOST_DEVICE (int ip) {
return predicate(src_data, ip, rng) ? 1 : 0;
reduce_op.eval(np, reduce_data, [=] AMREX_GPU_HOST_DEVICE (int ip)
{
return predicate(src_data, ip);
});
}
int const np_to_move = amrex::get<0>(reduce_data.value());
Expand All @@ -83,35 +114,14 @@ namespace impactx
AMREX_ALWAYS_ASSERT(SrcData::NAI == 0);
AMREX_ALWAYS_ASSERT(ptile_source.NumRuntimeIntComps() == 0);

// first runtime attribute in destination is s position when particle got lost
int const s_index = dest.NumRuntimeRealComps() - 1;
auto copy_and_mark_negative = [&s_index, &s_lost](DstData& dst, const SrcData& src, int src_ip, int dst_ip) noexcept
{
dst.m_aos[dst_ip] = src.m_aos[src_ip];

for (int j = 0; j < SrcData::NAR; ++j)
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
for (int j = 0; j < src.m_num_runtime_real; ++j)
dst.m_runtime_rdata[j][dst_ip] = src.m_runtime_rdata[j][src_ip];

// unused: integer compile-time or runtime attributes
//for (int j = 0; j < SrcData::NAI; ++j)
// dst.m_idata[j][dst_ip] = src.m_idata[j][src_ip];
//for (int j = 0; j < src.m_num_runtime_int; ++j)
// dst.m_runtime_idata[j][dst_ip] = src.m_runtime_idata[j][src_ip];

// flip id to positive in destination
dst.id(dst_ip) = amrex::Math::abs(dst.id(dst_ip));

// remember the current s of the ref particle when lost
dst.m_runtime_rdata[s_index][dst_ip] = s_lost;
};
// first runtime attribute in destination is s position where particle got lost
AMREX_ALWAYS_ASSERT(dest.NumRuntimeRealComps() > 0);

amrex::filterAndTransformParticles(
ptile_dest,
ptile_source,
predicate,
copy_and_mark_negative,
CopyAndMarkNegative{s_lost},
0,
dst_index
);
Expand Down

0 comments on commit b448f6c

Please sign in to comment.