Skip to content

Commit

Permalink
hipsparse -> rocsparse
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Dec 8, 2024
1 parent d05a3b4 commit 1c45876
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 29 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/dependencies/dependencies_hip.sh
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ sudo apt-get install -y --no-install-recommends \
rocrand-dev \
rocfft-dev \
rocprim-dev \
rocsparse-dev \
hipsparse-dev
rocsparse-dev

# hiprand-dev is a new package that does not exist in old versions
sudo apt-get install -y --no-install-recommends hiprand-dev || true
Expand Down
63 changes: 38 additions & 25 deletions Src/LinearSolvers/AMReX_SpMV.H
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#if defined(AMREX_USE_CUDA)
# include <cusparse.h>
#elif defined(AMREX_USE_HIP)
# include <hipsparse/hipsparse.h>
# include <rocsparse/rocsparse.h>
#elif defined(AMREX_USE_DPCPP)
# include <oneapi/mkl/spblas.hpp>
#endif
Expand All @@ -19,6 +19,8 @@ namespace amrex {
template <typename T>
void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
{
// xxxxx TODOL We might want to cache the cusparse and rocsparse handles

// xxxxx TODO: let's assume it's square matrix for now.
AMREX_ALWAYS_ASSERT(x.partition() == y.partition() &&
x.partition() == A.partition());
Expand Down Expand Up @@ -90,53 +92,64 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)

#elif defined(AMREX_USE_HIP)

hipsparseHandle_t handle;
hipsparseCreate(&handle);
hipsparseSetStream(handle, Gpu::gpuStream());
rocsparse_handle handle;
rocsparse_create_handle(&handle);
rocsparse_set_stream(handle, Gpu::gpuStream());

hipDataType data_type;
rocsparse_datatype data_type;
if constexpr (std::is_same_v<T,float>) {
data_type = HIP_R_32F;
data_type = rocsparse_datatype_f32_r;
} else if constexpr (std::is_same_v<T,double>) {
data_type = HIP_R_64F;
data_type = rocsparse_datatype_f64_r;
} else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
data_type = HIP_C_32F;
data_type = rocsparse_datatype_f32_c;
} else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
data_type = HIP_C_64F;
data_type = rocsparse_datatype_f64_c;
} else {
amrex::Abort("SpMV: unsupported data type");
}

hipsparseIndexType_t index_type = HIPSPARSE_INDEX_64I;
rocsparse_indextype index_type = rocsparse_indextype_i64;

hipsparseSpMatDescr_t mat_descr;
hipsparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat,
index_type, index_type, HIPSPARSE_INDEX_BASE_ZERO, data_type);
rocsparse_spmat_descr mat_descr;
rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col,
(void*)mat, index_type, index_type,
rocsparse_index_base_zero, data_type);

hipsparseDnVecDescr_t x_descr;
hipsparseCreateDnVec(&x_descr, ncols, (void*)px, data_type);
rocsparse_dnvec_descr x_descr;
rocsparse_create_dnvec_descr(&x_descr, ncols, (void*)px, data_type);

hipsparseDnVecDescr_t y_descr;
hipsparseCreateDnVec(&y_descr, nrows, (void*)py, data_type);
rocsparse_dnvec_descr y_descr;
rocsparse_create_dnvec_descr(&y_descr, nrows, (void*)py, data_type);

T alpha = T(1.0);
T beta = T(0.0);

std::size_t buffer_size;
hipsparseSpMV_bufferSize(handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, HIPSPARSE_SPMV_ALG_DEFAULT, &buffer_size);
rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, rocsparse_spmv_alg_default,
// rocsparse_spmv_stage_buffer_size,
&buffer_size, nullptr);

void* pbuffer = (void*)The_Arena()->alloc(buffer_size);

hipsparseSpMV(handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, HIPSPARSE_SPMV_ALG_DEFAULT, pbuffer);
#if 0
rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, rocsparse_spmv_alg_default,
rocsparse_spmv_stage_preprocess, &buffer_size, pbuffer);
#endif

rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, rocsparse_spmv_alg_default,
// rocsparse_spmv_stage_compute,
&buffer_size, pbuffer);

Gpu::streamSynchronize();

hipsparseDestroySpMat(mat_descr);
hipsparseDestroyDnVec(x_descr);
hipsparseDestroyDnVec(y_descr);
hipsparseDestroy(handle);
rocsparse_destroy_spmat_descr(mat_descr);
rocsparse_destroy_dnvec_descr(x_descr);
rocsparse_destroy_dnvec_descr(y_descr);
rocsparse_destroy_handle(handle);
The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_DPCPP)
Expand Down
3 changes: 1 addition & 2 deletions Tools/CMake/AMReXParallelBackends.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ if (AMReX_HIP)
find_package(hiprand REQUIRED CONFIG)
if (AMReX_LINEAR_SOLVERS)
find_package(rocsparse REQUIRED CONFIG)
find_package(hipsparse REQUIRED CONFIG)
endif()

if(AMReX_ROCTX)
Expand All @@ -319,7 +318,7 @@ if (AMReX_HIP)
endforeach()
if (AMReX_LINEAR_SOLVERS)
foreach(D IN LISTS AMReX_SPACEDIM)
target_link_libraries(amrex_${D}d PUBLIC hip::hipsparse roc::rocsparse)
target_link_libraries(amrex_${D}d PUBLIC roc::rocsparse)
endforeach()
endif()

Expand Down

0 comments on commit 1c45876

Please sign in to comment.