Skip to content

Commit

Permalink
Split the type list.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 6, 2025
1 parent 8e7c2bb commit 40399a8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/encoder/ordinal.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,21 @@ struct CatStrArrayView {
return this->offsets.size_bytes() + values.size_bytes();
}
};

// We keep a single type list here for supported types and use various transformations to
// add specializations. This way we can modify the type list with ease.

/**
* @brief All the primitive types supported by the encoder.
*/
using CatPrimIndexTypes =
std::tuple<std::int8_t, std::int16_t, std::int32_t, std::int64_t, float, double>;

/**
* @brief All the types supported by the encoder.
* @brief All the column types supported by the encoder.
*/
using CatIndexViewTypes =
std::tuple<enc::CatStrArrayView, Span<std::int8_t const>, Span<std::int16_t const>,
Span<std::int32_t const>, Span<std::int64_t const>, Span<float const>,
Span<double const>>;
using CatIndexViewTypes = decltype(std::tuple_cat(std::tuple<enc::CatStrArrayView>{},
PrimToSpan<CatPrimIndexTypes>::Type{}));

/**
* @brief Host categories view for a single column.
Expand Down
20 changes: 20 additions & 0 deletions src/encoder/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <tuple> // for tuple
#include <variant> // for variant

#include "xgboost/span.h" // for Span

#if defined(XGBOOST_USE_CUDA)

#include <cuda/std/variant> // for variant
Expand All @@ -27,7 +29,24 @@ struct Overloaded : Ts... {
template <typename... Ts>
ENC_DEVICE Overloaded(Ts...) -> Overloaded<Ts...>;

// Whether a type is a member of a type list (a.k.a tuple).
template <typename... Ts>
struct MemberOf;

template <typename T, typename... Ts>
struct MemberOf<T, std::tuple<Ts...>> : public std::disjunction<std::is_same<T, Ts>...> {};

// Convert primitive types to span types.
template <typename... Ts>
struct PrimToSpan;

template <typename... Ts>
struct PrimToSpan<std::tuple<Ts...>> {
using Type = std::tuple<xgboost::common::Span<std::add_const_t<Ts>>...>;
};

namespace cpu_impl {
// Convert tuple of types to variant of types.
template <typename... Ts>
struct TupToVar;

Expand All @@ -42,6 +61,7 @@ using TupToVarT = typename TupToVar<Ts...>::Type;

#if defined(XGBOOST_USE_CUDA)
namespace cuda_impl {
// Convert tuple of types to CUDA variant of types.
template <typename... Ts>
struct TupToVar {};

Expand Down

0 comments on commit 40399a8

Please sign in to comment.