From 191f1cc078bb00e74cf54a58e58f869f933e0748 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Fri, 25 Mar 2022 23:32:18 +0100 Subject: [PATCH] enable swap(...) for temporaries and references --- testing/device_reference.cu | 30 +++++++++++++++++++++++++++--- thrust/detail/reference.h | 34 +++++++++++++++++++++++++++++++--- thrust/device_reference.h | 30 ++++++++++++++++++++++++++++-- 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/testing/device_reference.cu b/testing/device_reference.cu index c30934d75..1fed94b8d 100644 --- a/testing/device_reference.cu +++ b/testing/device_reference.cu @@ -208,6 +208,7 @@ DECLARE_UNITTEST(TestDeviceReferenceManipulation); void TestDeviceReferenceSwap(void) { + using std::swap; typedef int T; thrust::device_vector v(2); @@ -218,14 +219,37 @@ void TestDeviceReferenceSwap(void) ref2 = 13; // test thrust::swap() - thrust::swap(ref1, ref2); + swap(ref1, ref2); ASSERT_EQUAL(13, ref1); ASSERT_EQUAL(7, ref2); + // test thrust::swap(device_reference, device_reference) + swap(v.front(), v.back()); + ASSERT_EQUAL(7, v.front()); + ASSERT_EQUAL(13, v.back()); + // test .swap() ref1.swap(ref2); - ASSERT_EQUAL(7, ref1); - ASSERT_EQUAL(13, ref2); + ASSERT_EQUAL(13, ref1); + ASSERT_EQUAL(7, ref2); + + // test .swap(device_reference) + v.front().swap(v.back()); + ASSERT_EQUAL(7, v.front()); + ASSERT_EQUAL(13, v.back()); + + // test thrust::swap(device_reference, T&) + T val = 29; + swap(v.front(), val); + ASSERT_EQUAL(7, val); + ASSERT_EQUAL(29, v.front()); + ASSERT_EQUAL(13, v.back()); + + // test thrust::swap(T&, device_reference) + swap(val, v.back()); + ASSERT_EQUAL(13, val); + ASSERT_EQUAL(29, v.front()); + ASSERT_EQUAL(7, v.back()); } DECLARE_UNITTEST(TestDeviceReferenceSwap); diff --git a/thrust/detail/reference.h b/thrust/detail/reference.h index 5cc13625d..ab330d314 100644 --- a/thrust/detail/reference.h +++ b/thrust/detail/reference.h @@ -163,7 +163,7 @@ class reference * \param other The \p tagged_reference to swap with. */ __host__ __device__ - void swap(derived_type& other) + void swap(derived_type other) { // Avoid default-constructing a system; instead, just use a null pointer // for dispatch. This assumes that `get_value` will not access any system @@ -372,7 +372,7 @@ class reference template __host__ __device__ - void swap(System* system, derived_type& other) + void swap(System* system, derived_type other) { using thrust::system::detail::generic::select_system; using thrust::system::detail::generic::iter_swap; @@ -509,10 +509,38 @@ class tagged_reference {}; */ template __host__ __device__ -void swap(tagged_reference& x, tagged_reference& y) +void swap(tagged_reference x, tagged_reference y) { x.swap(y); } +/*! Exchanges the values of two objects referred to by a \p tagged_reference and a regular reference. + * + * \param x The \p tagged_reference of interest. + * \param y The regular reference of interest. + */ +template +__host__ __device__ +void swap(Element& x, tagged_reference y) +{ + Element tmp = x; + x = y; + y = tmp; +} + +/*! Exchanges the values of two objects referred to by a regular reference and a \p tagged_reference. + * + * \param x The regular reference of interest. + * \param y The \p tagged_reference of interest. + */ +template +__host__ __device__ +void swap(tagged_reference x, Element& y) +{ + Element tmp = x; + x = y; + y = tmp; +} + THRUST_NAMESPACE_END diff --git a/thrust/device_reference.h b/thrust/device_reference.h index 512ab4c60..b56ca1f32 100644 --- a/thrust/device_reference.h +++ b/thrust/device_reference.h @@ -330,7 +330,7 @@ template * \p other The other \p device_reference with which to swap. */ __host__ __device__ - void swap(device_reference &other); + void swap(device_reference other); /*! Prefix increment operator increments the object referenced by this * \p device_reference. @@ -962,11 +962,37 @@ template */ template __host__ __device__ -void swap(device_reference& x, device_reference& y) +void swap(device_reference x, device_reference y) { x.swap(y); } +/*! swaps the value of a \p device_reference with a regular reference. + * \p x The \p device_reference of interest. + * \p y The regular reference of interest. + */ +template +__host__ __device__ +void swap(device_reference x, T &y) +{ + T tmp = x; + x = y; + y = tmp; +} + +/*! swaps the value of a regular reference with a \p device_reference. + * \p x The regular reference of interest. + * \p y The \p device_reference of interest. + */ +template +__host__ __device__ +void swap(T &x, device_reference y) +{ + T tmp = x; + x = y; + y = tmp; +} + // declare these methods for the purpose of Doxygenating them // they actually are defined for a derived-from class #if THRUST_DOXYGEN