diff --git a/cpp/include/raft/core/cusparse_macros.hpp b/cpp/include/raft/core/cusparse_macros.hpp index 9058f4847d..5a1968b529 100644 --- a/cpp/include/raft/core/cusparse_macros.hpp +++ b/cpp/include/raft/core/cusparse_macros.hpp @@ -34,7 +34,8 @@ // // (i.e., before including this header) // -#define CUDA_VER_10_1_UP (CUDART_VERSION >= 10100) +#define CUDA_VER_10_1_UP (CUDART_VERSION >= 10010) +#define CUDA_VER_12_4_UP (CUDART_VERSION >= 12040) namespace raft { @@ -59,7 +60,7 @@ namespace detail { inline const char* cusparse_error_to_string(cusparseStatus_t err) { -#if defined(CUDART_VERSION) && CUDART_VERSION >= 10100 +#if defined(CUDART_VERSION) && CUDART_VERSION >= 10010 return cusparseGetErrorString(err); #else // CUDART_VERSION switch (err) { diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index 08efbb7106..ae552cc687 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -393,6 +393,34 @@ inline cusparseStatus_t cusparsespmv(cusparseHandle_t handle, CUSPARSE_CHECK(cusparseSetStream(handle, stream)); return cusparseSpMV(handle, opA, alpha, matA, vecX, beta, vecY, CUDA_R_64F, alg, externalBuffer); } +// cusparseSpMV_preprocess is only available starting CUDA 12.4 +#if CUDA_VER_12_4_UP +template < + typename T, + typename std::enable_if_t || std::is_same_v>* = nullptr> +cusparseStatus_t cusparsespmv_preprocess(cusparseHandle_t handle, + cusparseOperation_t opA, + const T* alpha, + const cusparseSpMatDescr_t matA, + const cusparseDnVecDescr_t vecX, + const T* beta, + const cusparseDnVecDescr_t vecY, + cusparseSpMVAlg_t alg, + T* externalBuffer, + cudaStream_t stream) +{ + auto constexpr float_type = []() constexpr { + if constexpr (std::is_same_v) { + return CUDA_R_32F; + } else if constexpr (std::is_same_v) { + return CUDA_R_64F; + } + }(); + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSpMV_preprocess( + handle, opA, alpha, matA, vecX, beta, vecY, float_type, alg, externalBuffer); +} +#endif /** @} */ #else /**