Skip to content

Commit

Permalink
Replace make_load_iterator forward declaration with file include du…
Browse files Browse the repository at this point in the history
…e to missing definition when using NVRTC, and make changes to included thrust headers to make them NVRTC compilable
  • Loading branch information
NaderAlAwar committed Jan 31, 2025
1 parent 3a631e5 commit e8cd15f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 26 deletions.
12 changes: 1 addition & 11 deletions cub/cub/device/dispatch/kernels/merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,7 @@
#include <cub/util_policy_wrapper_t.cuh>
#include <cub/util_vsmem.cuh>

THRUST_NAMESPACE_BEGIN

namespace cuda_cub::core::detail
{
// We must forward declare here because make_load_iterator.h pulls in non NVRTC compilable code
template <class PtxPlan, class It>
typename detail::LoadIterator<PtxPlan, It>::type _CCCL_DEVICE _CCCL_FORCEINLINE
make_load_iterator(PtxPlan const&, It it);
} // namespace cuda_cub::core::detail

THRUST_NAMESPACE_END
#include <thrust/system/cuda/detail/core/make_load_iterator.h>

CUB_NAMESPACE_BEGIN

Expand Down
17 changes: 9 additions & 8 deletions thrust/thrust/detail/type_traits/pointer_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
#include <thrust/detail/type_traits/is_thrust_pointer.h>
#include <thrust/iterator/iterator_traits.h>

#include <cstddef>
#include <type_traits>
#include <cuda/std/cstddef>
#include <cuda/std/type_traits>

THRUST_NAMESPACE_BEGIN
namespace detail
Expand Down Expand Up @@ -62,7 +62,7 @@ struct pointer_difference
template <typename T>
struct pointer_difference<T*>
{
using type = std::ptrdiff_t;
using type = ::cuda::std::ptrdiff_t;
};

template <typename Ptr, typename T>
Expand Down Expand Up @@ -117,10 +117,10 @@ template <template <typename, typename, typename, typename...> class Ptr,
typename Tag,
typename... PtrTail,
typename T>
struct rebind_pointer<Ptr<OldT, Tag, typename std::add_lvalue_reference<OldT>::type, PtrTail...>, T>
struct rebind_pointer<Ptr<OldT, Tag, typename ::cuda::std::add_lvalue_reference<OldT>::type, PtrTail...>, T>
{
// static_assert(std::is_same<OldT, Tag>::value, "2");
using type = Ptr<T, Tag, typename std::add_lvalue_reference<T>::type, PtrTail...>;
using type = Ptr<T, Tag, typename ::cuda::std::add_lvalue_reference<T>::type, PtrTail...>;
};

// Rebind `thrust::pointer`-like things with native reference types and templated
Expand All @@ -131,11 +131,12 @@ template <template <typename, typename, typename, typename...> class Ptr,
template <typename...> class DerivedPtr,
typename... DerivedPtrTail,
typename T>
struct rebind_pointer<Ptr<OldT, Tag, typename std::add_lvalue_reference<OldT>::type, DerivedPtr<OldT, DerivedPtrTail...>>,
T>
struct rebind_pointer<
Ptr<OldT, Tag, typename ::cuda::std::add_lvalue_reference<OldT>::type, DerivedPtr<OldT, DerivedPtrTail...>>,
T>
{
// static_assert(std::is_same<OldT, Tag>::value, "3");
using type = Ptr<T, Tag, typename std::add_lvalue_reference<T>::type, DerivedPtr<T, DerivedPtrTail...>>;
using type = Ptr<T, Tag, typename ::cuda::std::add_lvalue_reference<T>::type, DerivedPtr<T, DerivedPtrTail...>>;
};

namespace pointer_traits_detail
Expand Down
16 changes: 9 additions & 7 deletions thrust/thrust/iterator/iterator_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
# pragma system_header
#endif // no system header

#include <iterator>
#include <cuda/std/iterator>

THRUST_NAMESPACE_BEGIN

/*! \p iterator_traits is a type trait class that provides a uniform
* interface for querying the properties of iterators at compile-time.
*/
template <typename T>
struct iterator_traits : std::iterator_traits<T>
struct iterator_traits : ::cuda::std::iterator_traits<T>
{};

template <typename Iterator>
Expand All @@ -70,8 +70,10 @@ struct iterator_system;

THRUST_NAMESPACE_END

#include <thrust/iterator/detail/any_system_tag.h>
#include <thrust/iterator/detail/device_system_tag.h>
#include <thrust/iterator/detail/host_system_tag.h>
#include <thrust/iterator/detail/iterator_traits.inl>
#include <thrust/iterator/detail/iterator_traversal_tags.h>
#if !_CCCL_COMPILER(NVRTC)
# include <thrust/iterator/detail/any_system_tag.h>
# include <thrust/iterator/detail/device_system_tag.h>
# include <thrust/iterator/detail/host_system_tag.h>
# include <thrust/iterator/detail/iterator_traits.inl>
# include <thrust/iterator/detail/iterator_traversal_tags.h>
#endif

0 comments on commit e8cd15f

Please sign in to comment.