Skip to content

Commit

Permalink
take SYCL implementation from DPCT
Browse files Browse the repository at this point in the history
  • Loading branch information
AuroraPerego committed Dec 11, 2023
1 parent 625884b commit 720be22
Showing 1 changed file with 35 additions and 37 deletions.
72 changes: 35 additions & 37 deletions include/alpaka/warp/WarpGenericSycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,18 @@ namespace alpaka::warp::trait
ALPAKA_ASSERT_OFFLOAD(width > 0);
ALPAKA_ASSERT_OFFLOAD(srcLane >= 0);

/* If width is smaller than the sub-group size, the sub-group needs to be split into assumed subdivisions of
width items each. The first item of each subdivision has the assumed index 0. The srcLane index is relative
to the subdivisions.
Example: If we assume a sub-group size of 32 and a width of 16 we will receive two subdivisions:
The first starts at sub-group index 0 and the second at sub-group index 16. For srcLane = 4 the
first subdivision will access the value at sub-group index 4 and the second at sub-group index 20. */
// If width is smaller than the sub-group size, the sub-group needs to be split into assumed
// subdivisions of width items each. The first item of each subdivision has the assumed index 0.
// The srcLane index is relative to the subdivisions.
//
// Example: If we assume a sub-group size of 32 and a width of 16 we will receive two subdivisions:
// The first starts at sub-group index 0 and the second at sub-group index 16. For srcLane = 4 the
// first subdivision will access the value at sub-group index 4 and the second at sub-group
// index 20.
auto const actual_group = warp.m_item_warp.get_sub_group();
auto const actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_group_id = actual_item_id / width;
auto const actual_src_id = static_cast<std::size_t>(srcLane + actual_group_id * width);
auto const src = sycl::id<1>{actual_src_id};

return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
unsigned int const start_index = actual_group.get_local_linear_id() / w * w;
return sycl::select_from_group(actual_group, value, start_index + static_cast<std::uint32_t>(srcLane) % w);
}
};

Expand All @@ -142,15 +141,16 @@ namespace alpaka::warp::trait
std::uint32_t offset, /* must be the same for all work-items in the group */
std::int32_t width)
{
std::int32_t offset_int = static_cast<std::int32_t>(offset);
auto const actual_group = warp.m_item_warp.get_sub_group();
auto actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_group_id = actual_item_id / width;
auto const actual_src_id = actual_item_id - offset_int;
auto const src = actual_src_id >= actual_group_id * width
? sycl::id<1>{static_cast<std::size_t>(actual_src_id)}
: sycl::id<1>{static_cast<std::size_t>(actual_item_id)};
return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
unsigned int const id = actual_group.get_local_linear_id();
unsigned int const start_index = id / w * w;
T result = sycl::shift_group_right(actual_group, value, offset);
if((id - start_index) < offset)
{
result = value;
}
return result;
}
};

Expand All @@ -164,33 +164,31 @@ namespace alpaka::warp::trait
std::uint32_t offset,
std::int32_t width)
{
std::int32_t offset_int = static_cast<std::int32_t>(offset);
auto const actual_group = warp.m_item_warp.get_sub_group();
auto actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_group_id = actual_item_id / width;
auto const actual_src_id = actual_item_id + offset_int;
auto const src = actual_src_id < (actual_group_id + 1) * width
? sycl::id<1>{static_cast<std::size_t>(actual_src_id)}
: sycl::id<1>{static_cast<std::size_t>(actual_item_id)};
return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
unsigned int const id = actual_group.get_local_linear_id();
unsigned int const end_index = (id / w + 1) * w;
T result = sycl::shift_group_left(actual_group, value, offset);
if((id + offset) >= end_index)
{
result = value;
}
return result;
}
};

// Trait specialization providing the XOR shuffle for the generic SYCL warp.
// NOTE(review): this span contains two `shfl_xor` signatures and two bodies back to back;
// presumably the first is the version removed by the commit and the second the version
// added by it. Confirm against the actual header before relying on either variant.
template<typename TDim>
struct ShflXor<warp::WarpGenericSycl<TDim>>
{
template<typename T>
// Variant 1 signature: the last parameter is unnamed, so `width` is ignored.
static auto shfl_xor(
warp::WarpGenericSycl<TDim> const& warp,
T value,
std::int32_t mask,
std::int32_t /*width*/)
// Variant 2 signature: same parameter list, but `width` is named and used in the body.
static auto shfl_xor(warp::WarpGenericSycl<TDim> const& warp, T value, std::int32_t mask, std::int32_t width)
{
// Sub-group of the calling work-item; both variants shuffle within it.
auto const actual_group = warp.m_item_warp.get_sub_group();
// Variant 1 body: the source lane is the caller's sub-group-local id XOR `mask`;
// the result is passed to select_from_group without any range check.
auto actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_src_id = actual_item_id ^ mask;
auto const src = sycl::id<1>{static_cast<std::size_t>(actual_src_id)};
return sycl::select_from_group(actual_group, value, src);
// Variant 2 body: the sub-group is treated as consecutive segments of `w` lanes.
std::uint32_t const w = static_cast<std::uint32_t>(width);
unsigned int const id = actual_group.get_local_linear_id();
// Index of the first lane of the segment that contains `id`.
unsigned int const start_index = id / w * w;
// Position of the caller inside its segment, XOR-ed with `mask` (converted to unsigned first).
unsigned int const target_offset = (id % w) ^ static_cast<std::uint32_t>(mask);
// If the XOR-ed position stays inside the segment, read from that lane of the same segment;
// otherwise the source lane is `id`, i.e. the caller reads its own `value`.
return sycl::select_from_group(actual_group, value, target_offset < w ? start_index + target_offset : id);
}
};
} // namespace alpaka::warp::trait
Expand Down

0 comments on commit 720be22

Please sign in to comment.