diff --git a/src/parcsr_mv/_hypre_parcsr_mv.h b/src/parcsr_mv/_hypre_parcsr_mv.h index 269db348e9..7481611826 100644 --- a/src/parcsr_mv/_hypre_parcsr_mv.h +++ b/src/parcsr_mv/_hypre_parcsr_mv.h @@ -741,6 +741,8 @@ HYPRE_Int HYPRE_ParVectorPrintBinaryIJ ( HYPRE_ParVector vector, const char *fil HYPRE_Int HYPRE_ParVectorSetConstantValues ( HYPRE_ParVector vector, HYPRE_Complex value ); HYPRE_Int HYPRE_ParVectorSetRandomValues ( HYPRE_ParVector vector, HYPRE_Int seed ); HYPRE_Int HYPRE_ParVectorCopy ( HYPRE_ParVector x, HYPRE_ParVector y ); +HYPRE_Int hypre_ParVectorStridedCopy( hypre_ParVector *x, HYPRE_Int istride, HYPRE_Int ostride, + HYPRE_Int size, HYPRE_Complex *data ); HYPRE_ParVector HYPRE_ParVectorCloneShallow ( HYPRE_ParVector x ); HYPRE_Int HYPRE_ParVectorScale ( HYPRE_Complex value, HYPRE_ParVector x ); HYPRE_Int HYPRE_ParVectorAxpy ( HYPRE_Complex alpha, HYPRE_ParVector x, HYPRE_ParVector y ); diff --git a/src/parcsr_mv/par_vector.c b/src/parcsr_mv/par_vector.c index 77e6afbbdc..5a95156798 100644 --- a/src/parcsr_mv/par_vector.c +++ b/src/parcsr_mv/par_vector.c @@ -372,6 +372,22 @@ hypre_ParVectorCopy( hypre_ParVector *x, return hypre_SeqVectorCopy(x_local, y_local); } +/*-------------------------------------------------------------------------- + * hypre_ParVectorStridedCopy + *--------------------------------------------------------------------------*/ + +HYPRE_Int +hypre_ParVectorStridedCopy( hypre_ParVector *x, + HYPRE_Int istride, + HYPRE_Int ostride, + HYPRE_Int size, + HYPRE_Complex *data) +{ + hypre_Vector *x_local = hypre_ParVectorLocalVector(x); + + return hypre_SeqVectorStridedCopy(x_local, istride, ostride, size, data); +} + /*-------------------------------------------------------------------------- * hypre_ParVectorCloneShallow * diff --git a/src/parcsr_mv/protos.h b/src/parcsr_mv/protos.h index 12fd43f993..fcdaf8e565 100644 --- a/src/parcsr_mv/protos.h +++ b/src/parcsr_mv/protos.h @@ -83,6 +83,8 @@ HYPRE_Int HYPRE_ParVectorPrintBinaryIJ ( HYPRE_ParVector vector, const char *fil HYPRE_Int HYPRE_ParVectorSetConstantValues ( HYPRE_ParVector vector, HYPRE_Complex value ); HYPRE_Int HYPRE_ParVectorSetRandomValues ( HYPRE_ParVector vector, HYPRE_Int seed ); HYPRE_Int HYPRE_ParVectorCopy ( HYPRE_ParVector x, HYPRE_ParVector y ); +HYPRE_Int hypre_ParVectorStridedCopy( hypre_ParVector *x, HYPRE_Int istride, HYPRE_Int ostride, + HYPRE_Int size, HYPRE_Complex *data ); HYPRE_ParVector HYPRE_ParVectorCloneShallow ( HYPRE_ParVector x ); HYPRE_Int HYPRE_ParVectorScale ( HYPRE_Complex value, HYPRE_ParVector x ); HYPRE_Int HYPRE_ParVectorAxpy ( HYPRE_Complex alpha, HYPRE_ParVector x, HYPRE_ParVector y ); diff --git a/src/seq_mv/protos.h b/src/seq_mv/protos.h index 00afd2bef2..f0541629bb 100644 --- a/src/seq_mv/protos.h +++ b/src/seq_mv/protos.h @@ -267,28 +267,23 @@ hypre_Vector *hypre_SeqVectorRead ( char *file_name ); HYPRE_Int hypre_SeqVectorPrint ( hypre_Vector *vector, char *file_name ); HYPRE_Int hypre_SeqVectorSetConstantValues ( hypre_Vector *v, HYPRE_Complex value ); HYPRE_Int hypre_SeqVectorSetConstantValuesHost ( hypre_Vector *v, HYPRE_Complex value ); -HYPRE_Int hypre_SeqVectorSetConstantValuesDevice ( hypre_Vector *v, HYPRE_Complex value ); HYPRE_Int hypre_SeqVectorSetRandomValues ( hypre_Vector *v, HYPRE_Int seed ); HYPRE_Int hypre_SeqVectorCopy ( hypre_Vector *x, hypre_Vector *y ); +HYPRE_Int hypre_SeqVectorStridedCopy( hypre_Vector *x, HYPRE_Int istride, HYPRE_Int ostride, + HYPRE_Int size, HYPRE_Complex *data); hypre_Vector *hypre_SeqVectorCloneDeep ( hypre_Vector *x ); hypre_Vector *hypre_SeqVectorCloneDeep_v2( hypre_Vector *x, HYPRE_MemoryLocation memory_location ); hypre_Vector *hypre_SeqVectorCloneShallow ( hypre_Vector *x ); HYPRE_Int hypre_SeqVectorMigrate( hypre_Vector *x, HYPRE_MemoryLocation memory_location ); HYPRE_Int hypre_SeqVectorScale( HYPRE_Complex alpha, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorScaleHost( HYPRE_Complex alpha, hypre_Vector *y ); -HYPRE_Int hypre_SeqVectorScaleDevice( HYPRE_Complex alpha, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorAxpy ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorAxpyHost ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); -HYPRE_Int hypre_SeqVectorAxpyDevice ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorAxpyz ( HYPRE_Complex alpha, hypre_Vector *x, HYPRE_Complex beta, hypre_Vector *y, hypre_Vector *z ); -HYPRE_Int hypre_SeqVectorAxpyzDevice ( HYPRE_Complex alpha, hypre_Vector *x, - HYPRE_Complex beta, hypre_Vector *y, - hypre_Vector *z ); HYPRE_Real hypre_SeqVectorInnerProd ( hypre_Vector *x, hypre_Vector *y ); HYPRE_Real hypre_SeqVectorInnerProdHost ( hypre_Vector *x, hypre_Vector *y ); -HYPRE_Real hypre_SeqVectorInnerProdDevice ( hypre_Vector *x, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorMassInnerProd(hypre_Vector *x, hypre_Vector **y, HYPRE_Int k, HYPRE_Int unroll, HYPRE_Real *result); HYPRE_Int hypre_SeqVectorMassInnerProd4(hypre_Vector *x, hypre_Vector **y, HYPRE_Int k, @@ -309,8 +304,6 @@ HYPRE_Int hypre_SeqVectorMassAxpy8(HYPRE_Complex *alpha, hypre_Vector **x, hypre HYPRE_Int k); HYPRE_Complex hypre_SeqVectorSumElts ( hypre_Vector *vector ); HYPRE_Complex hypre_SeqVectorSumEltsHost ( hypre_Vector *vector ); -HYPRE_Complex hypre_SeqVectorSumEltsDevice ( hypre_Vector *vector ); -HYPRE_Int hypre_SeqVectorPrefetch(hypre_Vector *x, HYPRE_MemoryLocation memory_location); //HYPRE_Int hypre_SeqVectorMax( HYPRE_Complex alpha, hypre_Vector *x, HYPRE_Complex beta, hypre_Vector *y ); HYPRE_Int hypreDevice_CSRSpAdd(HYPRE_Int ma, HYPRE_Int mb, HYPRE_Int n, HYPRE_Int nnzA, @@ -347,9 +340,6 @@ HYPRE_Int hypre_SeqVectorElmdivpyMarked( hypre_Vector *x, hypre_Vector *b, hypre HYPRE_Int *marker, HYPRE_Int marker_val ); HYPRE_Int hypre_SeqVectorElmdivpyHost( hypre_Vector *x, hypre_Vector *b, hypre_Vector *y, HYPRE_Int *marker, HYPRE_Int marker_val ); -HYPRE_Int hypre_SeqVectorElmdivpyDevice( hypre_Vector *x, hypre_Vector *b, hypre_Vector *y, - HYPRE_Int *marker, HYPRE_Int marker_val ); - HYPRE_Int hypre_CSRMatrixSpMVDevice( HYPRE_Int trans, HYPRE_Complex alpha, hypre_CSRMatrix *A, hypre_Vector *x, HYPRE_Complex beta, hypre_Vector *y, HYPRE_Int fill ); @@ -376,3 +366,19 @@ hypre_GpuMatData* hypre_CSRMatrixGetGPUMatData(hypre_CSRMatrix *matrix); #endif HYPRE_Int hypre_CSRMatrixSpMVAnalysisDevice(hypre_CSRMatrix *matrix); + +/* vector_device.c */ +HYPRE_Int hypre_SeqVectorSetConstantValuesDevice ( hypre_Vector *v, HYPRE_Complex value ); +HYPRE_Int hypre_SeqVectorScaleDevice( HYPRE_Complex alpha, hypre_Vector *y ); +HYPRE_Int hypre_SeqVectorAxpyDevice ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); +HYPRE_Int hypre_SeqVectorAxpyzDevice ( HYPRE_Complex alpha, hypre_Vector *x, + HYPRE_Complex beta, hypre_Vector *y, + hypre_Vector *z ); +HYPRE_Int hypre_SeqVectorElmdivpyDevice( hypre_Vector *x, hypre_Vector *b, hypre_Vector *y, + HYPRE_Int *marker, HYPRE_Int marker_val ); +HYPRE_Real hypre_SeqVectorInnerProdDevice ( hypre_Vector *x, hypre_Vector *y ); +HYPRE_Complex hypre_SeqVectorSumEltsDevice ( hypre_Vector *vector ); +HYPRE_Int hypre_SeqVectorStridedCopyDevice( hypre_Vector *vector, + HYPRE_Int istride, HYPRE_Int ostride, + HYPRE_Int size, HYPRE_Complex *data ); +HYPRE_Int hypre_SeqVectorPrefetch(hypre_Vector *x, HYPRE_MemoryLocation memory_location); diff --git a/src/seq_mv/seq_mv.h b/src/seq_mv/seq_mv.h index eb54a87ba2..6b4305096f 100644 --- a/src/seq_mv/seq_mv.h +++ b/src/seq_mv/seq_mv.h @@ -546,28 +546,23 @@ hypre_Vector *hypre_SeqVectorRead ( char *file_name ); HYPRE_Int hypre_SeqVectorPrint ( hypre_Vector *vector, char *file_name ); HYPRE_Int hypre_SeqVectorSetConstantValues ( hypre_Vector *v, HYPRE_Complex value ); HYPRE_Int hypre_SeqVectorSetConstantValuesHost ( hypre_Vector *v, HYPRE_Complex value ); -HYPRE_Int hypre_SeqVectorSetConstantValuesDevice ( hypre_Vector *v, HYPRE_Complex value ); HYPRE_Int hypre_SeqVectorSetRandomValues ( hypre_Vector *v, HYPRE_Int seed ); HYPRE_Int hypre_SeqVectorCopy ( hypre_Vector *x, hypre_Vector *y ); +HYPRE_Int hypre_SeqVectorStridedCopy( hypre_Vector *x, HYPRE_Int istride, HYPRE_Int ostride, + HYPRE_Int size, HYPRE_Complex *data); hypre_Vector *hypre_SeqVectorCloneDeep ( hypre_Vector *x ); hypre_Vector *hypre_SeqVectorCloneDeep_v2( hypre_Vector *x, HYPRE_MemoryLocation memory_location ); hypre_Vector *hypre_SeqVectorCloneShallow ( hypre_Vector *x ); HYPRE_Int hypre_SeqVectorMigrate( hypre_Vector *x, HYPRE_MemoryLocation memory_location ); HYPRE_Int hypre_SeqVectorScale( HYPRE_Complex alpha, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorScaleHost( HYPRE_Complex alpha, hypre_Vector *y ); -HYPRE_Int hypre_SeqVectorScaleDevice( HYPRE_Complex alpha, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorAxpy ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorAxpyHost ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); -HYPRE_Int hypre_SeqVectorAxpyDevice ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorAxpyz ( HYPRE_Complex alpha, hypre_Vector *x, HYPRE_Complex beta, hypre_Vector *y, hypre_Vector *z ); -HYPRE_Int hypre_SeqVectorAxpyzDevice ( HYPRE_Complex alpha, hypre_Vector *x, - HYPRE_Complex beta, hypre_Vector *y, - hypre_Vector *z ); HYPRE_Real hypre_SeqVectorInnerProd ( hypre_Vector *x, hypre_Vector *y ); HYPRE_Real hypre_SeqVectorInnerProdHost ( hypre_Vector *x, hypre_Vector *y ); -HYPRE_Real hypre_SeqVectorInnerProdDevice ( hypre_Vector *x, hypre_Vector *y ); HYPRE_Int hypre_SeqVectorMassInnerProd(hypre_Vector *x, hypre_Vector **y, HYPRE_Int k, HYPRE_Int unroll, HYPRE_Real *result); HYPRE_Int hypre_SeqVectorMassInnerProd4(hypre_Vector *x, hypre_Vector **y, HYPRE_Int k, @@ -588,8 +583,6 @@ HYPRE_Int hypre_SeqVectorMassAxpy8(HYPRE_Complex *alpha, hypre_Vector **x, hypre HYPRE_Int k); HYPRE_Complex hypre_SeqVectorSumElts ( hypre_Vector *vector ); HYPRE_Complex hypre_SeqVectorSumEltsHost ( hypre_Vector *vector ); -HYPRE_Complex hypre_SeqVectorSumEltsDevice ( hypre_Vector *vector ); -HYPRE_Int hypre_SeqVectorPrefetch(hypre_Vector *x, HYPRE_MemoryLocation memory_location); //HYPRE_Int hypre_SeqVectorMax( HYPRE_Complex alpha, hypre_Vector *x, HYPRE_Complex beta, hypre_Vector *y ); HYPRE_Int hypreDevice_CSRSpAdd(HYPRE_Int ma, HYPRE_Int mb, HYPRE_Int n, HYPRE_Int nnzA, @@ -626,9 +619,6 @@ HYPRE_Int hypre_SeqVectorElmdivpyMarked( hypre_Vector *x, hypre_Vector *b, hypre HYPRE_Int *marker, HYPRE_Int marker_val ); HYPRE_Int hypre_SeqVectorElmdivpyHost( hypre_Vector *x, hypre_Vector *b, hypre_Vector *y, HYPRE_Int *marker, HYPRE_Int marker_val ); -HYPRE_Int hypre_SeqVectorElmdivpyDevice( hypre_Vector *x, hypre_Vector *b, hypre_Vector *y, - HYPRE_Int *marker, HYPRE_Int marker_val ); - HYPRE_Int hypre_CSRMatrixSpMVDevice( HYPRE_Int trans, HYPRE_Complex alpha, hypre_CSRMatrix *A, hypre_Vector *x, HYPRE_Complex beta, hypre_Vector *y, HYPRE_Int fill ); @@ -656,6 +646,22 @@ hypre_GpuMatData* hypre_CSRMatrixGetGPUMatData(hypre_CSRMatrix *matrix); HYPRE_Int hypre_CSRMatrixSpMVAnalysisDevice(hypre_CSRMatrix *matrix); +/* vector_device.c */ +HYPRE_Int hypre_SeqVectorSetConstantValuesDevice ( hypre_Vector *v, HYPRE_Complex value ); +HYPRE_Int hypre_SeqVectorScaleDevice( HYPRE_Complex alpha, hypre_Vector *y ); +HYPRE_Int hypre_SeqVectorAxpyDevice ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y ); +HYPRE_Int hypre_SeqVectorAxpyzDevice ( HYPRE_Complex alpha, hypre_Vector *x, + HYPRE_Complex beta, hypre_Vector *y, + hypre_Vector *z ); +HYPRE_Int hypre_SeqVectorElmdivpyDevice( hypre_Vector *x, hypre_Vector *b, hypre_Vector *y, + HYPRE_Int *marker, HYPRE_Int marker_val ); +HYPRE_Real hypre_SeqVectorInnerProdDevice ( hypre_Vector *x, hypre_Vector *y ); +HYPRE_Complex hypre_SeqVectorSumEltsDevice ( hypre_Vector *vector ); +HYPRE_Int hypre_SeqVectorStridedCopyDevice( hypre_Vector *vector, + HYPRE_Int istride, HYPRE_Int ostride, + HYPRE_Int size, HYPRE_Complex *data ); +HYPRE_Int hypre_SeqVectorPrefetch(hypre_Vector *x, HYPRE_MemoryLocation memory_location); + #ifdef __cplusplus } #endif diff --git a/src/seq_mv/vector.c b/src/seq_mv/vector.c index 6b568dafbc..d0a8f1700f 100644 --- a/src/seq_mv/vector.c +++ b/src/seq_mv/vector.c @@ -459,6 +459,67 @@ hypre_SeqVectorCopy( hypre_Vector *x, return hypre_error_flag; } +/*-------------------------------------------------------------------------- + * hypre_SeqVectorStridedCopy + * + * Perform strided copy from a data array to x->data. + * + * We assume that the data array lives in the same memory location as x->data + *--------------------------------------------------------------------------*/ + +HYPRE_Int +hypre_SeqVectorStridedCopy( hypre_Vector *x, + HYPRE_Int istride, + HYPRE_Int ostride, + HYPRE_Int size, + HYPRE_Complex *data) +{ + HYPRE_Int x_size = hypre_VectorSize(x); + HYPRE_Complex *x_data = hypre_VectorData(x); + + HYPRE_Int i; + + /* Sanity checks */ + if (istride < 1) + { + hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Input stride needs to be greater than zero!"); + return hypre_error_flag; + } + + if (ostride < 1) + { + hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Output stride needs to be greater than zero!"); + return hypre_error_flag; + } + + if (x_size < (size / istride) * ostride) + { + hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Not enough space in x!"); + return hypre_error_flag; + } + +#if defined(HYPRE_USING_GPU) + HYPRE_ExecutionPolicy exec = hypre_GetExecPolicy1(hypre_VectorMemoryLocation(x)); + + if (exec == HYPRE_EXEC_DEVICE) + { + hypre_SeqVectorStridedCopyDevice(x, istride, ostride, size, data); + } + else +#endif + { +#if defined(HYPRE_USING_OPENMP) + #pragma omp parallel for private(i) HYPRE_SMP_SCHEDULE +#endif + for (i = 0; i < size; i += istride) + { + x_data[(i / istride) * ostride] = data[i]; + } + } + + return hypre_error_flag; +} + /*-------------------------------------------------------------------------- * hypre_SeqVectorCloneDeep_v2 *--------------------------------------------------------------------------*/ diff --git a/src/seq_mv/vector_device.c b/src/seq_mv/vector_device.c index 7053e5de22..0ae7b82f4d 100644 --- a/src/seq_mv/vector_device.c +++ b/src/seq_mv/vector_device.c @@ -339,6 +339,36 @@ hypre_SeqVectorSumEltsDevice( hypre_Vector *vector ) return sum; } +/*-------------------------------------------------------------------------- + * hypre_SeqVectorStridedCopyDevice + *--------------------------------------------------------------------------*/ + +HYPRE_Int +hypre_SeqVectorStridedCopyDevice( hypre_Vector *vector, + HYPRE_Int istride, + HYPRE_Int ostride, + HYPRE_Int size, + HYPRE_Complex *data) +{ + HYPRE_Complex *v_data = hypre_VectorData(vector); + +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) + auto begin = thrust::make_counting_iterator(0); + auto last = thrust::make_counting_iterator(size / istride); + + HYPRE_THRUST_CALL( transform, begin, last, + thrust::make_permutation_iterator(v_data, + thrust::make_transform_iterator(begin, + hypreFunctor_IndexStrided(ostride))), + hypreFunctor_ArrayStridedAccess(istride, data) ); + +#elif defined(HYPRE_USING_DEVICE_OPENMP) || defined(HYPRE_USING_SYCL) + hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Not implemented!"); +#endif + + return hypre_error_flag; +} + /*-------------------------------------------------------------------------- * hypre_SeqVectorPrefetch *--------------------------------------------------------------------------*/ diff --git a/src/utilities/_hypre_utilities.hpp b/src/utilities/_hypre_utilities.hpp index e7c6fff836..2fe5fee1dc 100644 --- a/src/utilities/_hypre_utilities.hpp +++ b/src/utilities/_hypre_utilities.hpp @@ -45,6 +45,52 @@ struct hypreFunctor_DenseMatrixIdentity } }; +/*-------------------------------------------------------------------------- + * hypreFunctor_ArrayStridedAccess + * + * Functor for performing strided data access on a templated array. + * + * The stride interval "s_" is used to access every "s_"-th element + * from the source array "a_". + * + * It is templated to support various data types for the array. + *--------------------------------------------------------------------------*/ + +template +struct hypreFunctor_ArrayStridedAccess +{ + HYPRE_Int s_; + T *a_; + + hypreFunctor_ArrayStridedAccess(HYPRE_Int s, T *a) : s_(s), a_(a) {} + + __host__ __device__ T operator()(HYPRE_Int i) + { + return a_[i * s_]; + } +}; + +/*-------------------------------------------------------------------------- + * hypreFunctor_IndexStrided + * + * This functor multiplies a given index "i" by a specified stride "s_". + * + * It is templated to support various data types for the index and stride. + *--------------------------------------------------------------------------*/ + +template +struct hypreFunctor_IndexStrided +{ + T s_; + + hypreFunctor_IndexStrided(T s) : s_(s) {} + + __host__ __device__ T operator()(const T i) const + { + return i * s_; + } +}; + /*-------------------------------------------------------------------------- * hypreFunctor_IndexCycle *--------------------------------------------------------------------------*/ @@ -70,6 +116,49 @@ struct hypreFunctor_IndexCycle * SPDX-License-Identifier: (Apache-2.0 OR MIT) ******************************************************************************/ +#ifndef HYPRE_PREDICATES_H +#define HYPRE_PREDICATES_H + +/****************************************************************************** + * + * Header file defining predicates for thrust used throughout hypre + * + *****************************************************************************/ + +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) + +/*-------------------------------------------------------------------------- + * hyprePred_StridedAccess + * + * This struct defines a predicate for strided access in array-like data. + * + * It is used to determine if an element at a given index should be processed + * or not, based on a specified stride. The operator() returns true when the + * index is a multiple of the stride, indicating the element at that index + * is part of the strided subset. + *--------------------------------------------------------------------------*/ + +struct hyprePred_StridedAccess +{ + HYPRE_Int s_; + + hyprePred_StridedAccess(HYPRE_Int s) : s_(s) {} + + __host__ __device__ HYPRE_Int operator()(const HYPRE_Int i) const + { + return (!(i % s_)); + } +}; + +#endif /* if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) */ +#endif /* ifndef HYPRE_PREDICATES_H */ +/****************************************************************************** + * Copyright (c) 1998 Lawrence Livermore National Security, LLC and other + * HYPRE Project Developers. See the top-level COPYRIGHT file for details. + * + * SPDX-License-Identifier: (Apache-2.0 OR MIT) + ******************************************************************************/ + #ifndef DEVICE_ALLOCATOR_H #define DEVICE_ALLOCATOR_H diff --git a/src/utilities/functors.h b/src/utilities/functors.h index 959a592e3c..75cc8c954d 100644 --- a/src/utilities/functors.h +++ b/src/utilities/functors.h @@ -34,6 +34,52 @@ struct hypreFunctor_DenseMatrixIdentity } }; +/*-------------------------------------------------------------------------- + * hypreFunctor_ArrayStridedAccess + * + * Functor for performing strided data access on a templated array. + * + * The stride interval "s_" is used to access every "s_"-th element + * from the source array "a_". + * + * It is templated to support various data types for the array. + *--------------------------------------------------------------------------*/ + +template +struct hypreFunctor_ArrayStridedAccess +{ + HYPRE_Int s_; + T *a_; + + hypreFunctor_ArrayStridedAccess(HYPRE_Int s, T *a) : s_(s), a_(a) {} + + __host__ __device__ T operator()(HYPRE_Int i) + { + return a_[i * s_]; + } +}; + +/*-------------------------------------------------------------------------- + * hypreFunctor_IndexStrided + * + * This functor multiplies a given index "i" by a specified stride "s_". + * + * It is templated to support various data types for the index and stride. + *--------------------------------------------------------------------------*/ + +template +struct hypreFunctor_IndexStrided +{ + T s_; + + hypreFunctor_IndexStrided(T s) : s_(s) {} + + __host__ __device__ T operator()(const T i) const + { + return i * s_; + } +}; + /*-------------------------------------------------------------------------- * hypreFunctor_IndexCycle *--------------------------------------------------------------------------*/ diff --git a/src/utilities/headers b/src/utilities/headers index f8ecf7d7c3..f4fef2202b 100755 --- a/src/utilities/headers +++ b/src/utilities/headers @@ -97,6 +97,7 @@ extern "C++" { #=========================================================================== cat functors.h >> $INTERNAL_HEADER +cat predicates.h >> $INTERNAL_HEADER cat device_allocator.h >> $INTERNAL_HEADER cat device_utils.h >> $INTERNAL_HEADER cat device_reducer.h >> $INTERNAL_HEADER diff --git a/src/utilities/predicates.h b/src/utilities/predicates.h new file mode 100644 index 0000000000..0f01d5c9ab --- /dev/null +++ b/src/utilities/predicates.h @@ -0,0 +1,43 @@ +/****************************************************************************** + * Copyright (c) 1998 Lawrence Livermore National Security, LLC and other + * HYPRE Project Developers. See the top-level COPYRIGHT file for details. + * + * SPDX-License-Identifier: (Apache-2.0 OR MIT) + ******************************************************************************/ + +#ifndef HYPRE_PREDICATES_H +#define HYPRE_PREDICATES_H + +/****************************************************************************** + * + * Header file defining predicates for thrust used throughout hypre + * + *****************************************************************************/ + +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) + +/*-------------------------------------------------------------------------- + * hyprePred_StridedAccess + * + * This struct defines a predicate for strided access in array-like data. + * + * It is used to determine if an element at a given index should be processed + * or not, based on a specified stride. The operator() returns true when the + * index is a multiple of the stride, indicating the element at that index + * is part of the strided subset. + *--------------------------------------------------------------------------*/ + +struct hyprePred_StridedAccess +{ + HYPRE_Int s_; + + hyprePred_StridedAccess(HYPRE_Int s) : s_(s) {} + + __host__ __device__ HYPRE_Int operator()(const HYPRE_Int i) const + { + return (!(i % s_)); + } +}; + +#endif /* if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) */ +#endif /* ifndef HYPRE_PREDICATES_H */