From 8edf01f0b74178ea0541768dca0e3d6c6e2fd59e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 16 Jan 2024 20:34:12 -0500 Subject: [PATCH] [Contrib] Workspace for cuBLAS backend This PR adds a 32MB workspace for cuBLAS backend, so that functions like `cublasLtMatmul` can take the workspace as input. The workspace is managed under CuBlasLtThreadEntry so that it will be allocated only once in each thread. --- src/runtime/contrib/cublas/cublas.cc | 20 ++++++++++++++++--- .../contrib/cublas/cublas_json_runtime.cc | 5 +++-- src/runtime/contrib/cublas/cublas_utils.cc | 14 ++++++++++++- src/runtime/contrib/cublas/cublas_utils.h | 11 +++++++++- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 10db3b1c50ab..7a867f4bae18 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -135,9 +135,10 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 -void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, const DLTensor* B, +void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, + cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, - cublasLtEpilogue_t epilogue) { + void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue) { ICHECK(TypeEqual(A->dtype, B->dtype)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? 
!transa : transa; @@ -265,8 +266,21 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, auto B_data = static_cast(B->data) + B->byte_offset; auto C_data = static_cast(C->data) + C->byte_offset; + cublasLtMatmulPreferenceSetAttribute(matmul_pref_desc, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(size_t)); + + cublasLtMatmulHeuristicResult_t heuristic_result = {}; + int returned_result = 0; + CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc, + matmul_pref_desc, 1, &heuristic_result, + &returned_result)); + if (returned_result == 0) { + CHECK_CUBLAS_ERROR(CUBLAS_STATUS_NOT_SUPPORTED); + } + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta, - C_data, C_desc, C_data, C_desc, nullptr, nullptr, 0, stream)); + C_data, C_desc, C_data, C_desc, &heuristic_result.algo, + workspace_ptr, workspace_size, stream)); cublasLtMatmulDescDestroy(op_desc); cublasLtMatrixLayoutDestroy(A_desc); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 23e35d2f7188..1a072a92eb8b 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -129,8 +129,9 @@ class CublasJSONRuntime : public JSONRuntimeBase { auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT); - tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr, bias_ptr, out_ptr, - transa, transb, epilogue); + tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr, + b_ptr, bias_ptr, out_ptr, transa, transb, + entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue); } } } diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 5cd07cf71dc6..5844f802fd84 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ 
b/src/runtime/contrib/cublas/cublas_utils.cc @@ -48,13 +48,25 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { return retval; } -CuBlasLtThreadEntry::CuBlasLtThreadEntry() { CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); } +CuBlasLtThreadEntry::CuBlasLtThreadEntry() { + CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&matmul_pref_desc)); + CUDA_CALL(cudaMalloc(&workspace_ptr, workspace_size)); +} CuBlasLtThreadEntry::~CuBlasLtThreadEntry() { if (handle) { cublasLtDestroy(handle); handle = nullptr; } + if (matmul_pref_desc) { + cublasLtMatmulPreferenceDestroy(matmul_pref_desc); + matmul_pref_desc = nullptr; + } + if (workspace_ptr != nullptr) { + cudaFree(workspace_ptr); + workspace_ptr = nullptr; + } } typedef dmlc::ThreadLocalStore CuBlasLtThreadStore; diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 82c77cfb5ef2..5c5cb6920860 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -80,7 +80,14 @@ struct CuBlasThreadEntry { struct CuBlasLtThreadEntry { CuBlasLtThreadEntry(); ~CuBlasLtThreadEntry(); + cublasLtHandle_t handle{nullptr}; + cublasLtMatmulPreference_t matmul_pref_desc{nullptr}; + void* workspace_ptr{nullptr}; + // 32MB workspace as suggested by NVIDIA + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace. + static constexpr const size_t workspace_size = 33554432; + static CuBlasLtThreadEntry* ThreadLocal(); }; // CuBlasLtThreadEntry @@ -113,8 +120,10 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { } /*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. 
*/ -void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, const DLTensor* B, +void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, + cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, + void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT); } // namespace contrib