Skip to content

Commit

Permalink
Only test half/bfloat16 when libcu++ supports them
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 30, 2025
1 parent b1f2e63 commit d8014bf
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 38 deletions.
16 changes: 8 additions & 8 deletions c2h/generators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,15 @@ template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, float* out, std::size_t element_size);
template void init_key_segments(
const c2h::device_vector<std::uint32_t>& segment_offsets, custom_type_state_t* out, std::size_t element_size);
#ifdef _CCCL_HAS_NVFP16
#ifdef TEST_HALF_T
template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, half_t* out, std::size_t element_size);
#endif // _CCCL_HAS_NVFP16
#endif // TEST_HALF_T

#ifdef _CCCL_HAS_NVBF16
#ifdef TEST_BF_T
template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, bfloat16_t* out, std::size_t element_size);
#endif // _CCCL_HAS_NVBF16
#endif // TEST_BF_T
} // namespace detail

template <typename T>
Expand Down Expand Up @@ -552,15 +552,15 @@ INSTANTIATE(double);
INSTANTIATE(bool);
INSTANTIATE(char);

#ifdef _CCCL_HAS_NVFP16
#ifdef TEST_HALF_T
INSTANTIATE(half_t);
INSTANTIATE(__half);
#endif // _CCCL_HAS_NVFP16
#endif // TEST_HALF_T

#ifdef _CCCL_HAS_NVBF16
#ifdef TEST_BF_T
INSTANTIATE(bfloat16_t);
INSTANTIATE(__nv_bfloat16);
#endif // _CCCL_HAS_NVBF16
#endif // TEST_BF_T

#undef INSTANTIATE_RND
#undef INSTANTIATE_MOD
Expand Down
8 changes: 6 additions & 2 deletions c2h/include/c2h/extended_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@
#include <cuda/__cccl_config>

#ifndef TEST_HALF_T
# define TEST_HALF_T _CCCL_HAS_NVFP16
# if defined(_CCCL_HAS_NVFP16) && defined(_LIBCUDACXX_HAS_NVFP16)
# define TEST_HALF_T
# endif
#endif

#ifndef TEST_BF_T
# define TEST_BF_T _CCCL_HAS_NVBF16
# if defined(_CCCL_HAS_NVBF16) && defined(_LIBCUDACXX_HAS_NVBF16)
# define TEST_BF_T
# endif
#endif

#ifdef TEST_HALF_T
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_segmented_sort_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,15 @@ struct unwrap_value_t_impl
using type = T;
};

#if TEST_HALF_T
#ifdef TEST_HALF_T
template <>
struct unwrap_value_t_impl<half_t>
{
using type = __half;
};
#endif

#if TEST_BF_T
#ifdef TEST_BF_T
template <>
struct unwrap_value_t_impl<bfloat16_t>
{
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ auto cast_if_half_pointer(T* p) -> T*
return p;
}

#if TEST_HALF_T
#ifdef TEST_HALF_T
auto cast_if_half_pointer(half_t* p) -> __half*
{
return reinterpret_cast<__half*>(p);
Expand Down Expand Up @@ -412,7 +412,7 @@ using types =
std::uint32_t,
std::int64_t,
std::uint64_t,
#if TEST_HALF_T
#ifdef TEST_HALF_T
half_t,
#endif
float,
Expand Down
13 changes: 6 additions & 7 deletions cub/test/catch2_test_device_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,13 @@ using full_type_list = c2h::type_list<type_pair<uchar3>, type_pair<ulonglong4>>;
// clang-format off
using full_type_list = c2h::type_list<
type_pair<custom_t>
#if TEST_HALF_T
, type_pair<half_t> // testing half
#endif
#if TEST_BF_T
, type_pair<bfloat16_t> // testing bf16

#ifdef TEST_HALF_T
, type_pair<half_t>
#endif // TEST_HALF_T
#ifdef TEST_BF_T
, type_pair<bfloat16_t>
#endif // TEST_BF_T
>;
#endif
// clang-format on
#elif TEST_TYPES == 4
// DPX SIMD instructions
Expand Down
15 changes: 10 additions & 5 deletions cub/test/catch2_test_device_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
#include <c2h/test_util_vec.h>
#include <nv/target>

#if TEST_HALF_T
#ifdef TEST_HALF_T
// Half support is provided by SM53+. We currently test against a few older architectures.
// The specializations below can be removed once we drop these architectures.

Expand Down Expand Up @@ -107,8 +107,13 @@ __host__ __device__ __forceinline__ //

return a;
}

CUB_NAMESPACE_END

#endif // TEST_HALF_T

CUB_NAMESPACE_BEGIN

/**
* @brief Introduces the required NumericTraits for `c2h::custom_type_t`.
*/
Expand Down Expand Up @@ -173,15 +178,15 @@ struct ExtendedFloatSum
return result;
}

#if TEST_HALF_T
#ifdef TEST_HALF_T
__host__ __device__ __half operator()(__half a, __half b) const
{
uint16_t result = this->operator()(half_t{a}, half_t(b)).raw();
return reinterpret_cast<__half&>(result);
}
#endif

#if TEST_BF_T
#ifdef TEST_BF_T
__device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const
{
uint16_t result = this->operator()(bfloat16_t{a}, bfloat16_t(b)).raw();
Expand All @@ -196,7 +201,7 @@ inline It unwrap_it(It it)
return it;
}

#if TEST_HALF_T
#ifdef TEST_HALF_T
inline __half* unwrap_it(half_t* it)
{
return reinterpret_cast<__half*>(it);
Expand All @@ -211,7 +216,7 @@ inline thrust::constant_iterator<__half, OffsetT> unwrap_it(thrust::constant_ite
}
#endif

#if TEST_BF_T
#ifdef TEST_BF_T
inline __nv_bfloat16* unwrap_it(bfloat16_t* it)
{
return reinterpret_cast<__nv_bfloat16*>(it);
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_reduce_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ using full_type_list = c2h::type_list<type_triple<uchar3, uchar3, custom_t>, typ
// clang-format off
using full_type_list = c2h::type_list<
type_triple<custom_t>
#if TEST_HALF_T
#ifdef TEST_HALF_T
, type_triple<half_t> // testing half
#endif
#if TEST_BF_T
#ifdef TEST_BF_T
, type_triple<bfloat16_t> // testing bf16
#endif
>;
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ using full_type_list = c2h::type_list<type_pair<uchar3>, type_pair<ulonglong4>>;
// clang-format off
using full_type_list = c2h::type_list<
type_pair<custom_t>
#if TEST_HALF_T
#ifdef TEST_HALF_T
, type_pair<half_t> // testing half
#endif
#if TEST_BF_T
#ifdef TEST_BF_T
, type_pair<bfloat16_t> // testing bf16
#endif
>;
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_scan_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ using full_type_list =
// clang-format off
using full_type_list = c2h::type_list<
type_quad<custom_t, custom_t, custom_t>
#if TEST_HALF_T
#ifdef TEST_HALF_T
, type_quad<half_t> // testing half
#endif
#if TEST_BF_T
#ifdef TEST_BF_T
, type_quad<bfloat16_t> // testing bf16
#endif
>;
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_segmented_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ using full_type_list = c2h::type_list<type_pair<uchar3>, type_pair<ulonglong4>>;
// clang-format off
using full_type_list = c2h::type_list<
type_pair<custom_t>
#if TEST_HALF_T
#ifdef TEST_HALF_T
, type_pair<half_t> // testing half
#endif
#if TEST_BF_T
#ifdef TEST_BF_T
, type_pair<bfloat16_t> // testing bf16
#endif
>;
Expand Down
11 changes: 9 additions & 2 deletions cub/test/catch2_test_device_segmented_sort_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "insert_nested_NVTX_range_guard.h"
// above header needs to be included first
#include <cub/device/device_segmented_sort.cuh>
#include <cub/util_type.cuh>

#include "catch2_radix_sort_helper.cuh"
#include "catch2_segmented_sort_helper.cuh"
Expand All @@ -37,17 +38,23 @@
// graph launch.
// %PARAM% TEST_LAUNCH lid 0:1

static_assert(::cuda::std::__is_extended_floating_point<__half>::value);
static_assert(::cuda::is_floating_point_v<__half>);

cub::Twiddle<__half>::UnsignedBits a;
cub::Twiddle<half_t>::UnsignedBits b;

DECLARE_LAUNCH_WRAPPER(cub::DeviceSegmentedSort::StableSortKeys, stable_sort_keys);

using key_types =
c2h::type_list<bool,
std::uint8_t,
std::uint64_t
#if TEST_HALF_T
#ifdef TEST_HALF_T
,
half_t
#endif
#if TEST_BF_T
#ifdef TEST_BF_T
,
bfloat16_t
#endif
Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_segmented_sort_pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ using pair_types =
c2h::type_list<c2h::type_list<bool, std::uint8_t>,
c2h::type_list<std::int8_t, std::uint64_t>,
c2h::type_list<double, float>
#if TEST_HALF_T
#ifdef TEST_HALF_T
,
c2h::type_list<half_t, std::int8_t>
#endif
#if TEST_BF_T
#ifdef TEST_BF_T
,
c2h::type_list<bfloat16_t, float>
#endif
Expand Down

0 comments on commit d8014bf

Please sign in to comment.