Skip to content

Commit

Permalink
[Unity][Contrib] Workspace for cuBLAS backend
Browse files Browse the repository at this point in the history
This PR adds a 32MB workspace for cuBLAS backend, so that
functions like `cublasLtMatmul` can take the workspace as
input.

The workspace is managed under CuBlasThreadEntry so that
it will be allocated only once in each thread.
  • Loading branch information
MasterJH5574 committed Jan 17, 2024
1 parent a2a1b53 commit a102768
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 7 deletions.
20 changes: 17 additions & 3 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }

#if CUDART_VERSION >= 10010

void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, const DLTensor* B,
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
cublasLtEpilogue_t epilogue) {
void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue) {
ICHECK(TypeEqual(A->dtype, B->dtype));
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
Expand Down Expand Up @@ -265,8 +266,21 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A,
auto B_data = static_cast<char*>(B->data) + B->byte_offset;
auto C_data = static_cast<char*>(C->data) + C->byte_offset;

cublasLtMatmulPreferenceSetAttribute(matmul_pref_desc, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(size_t));

cublasLtMatmulHeuristicResult_t heuristic_result = {};
int returned_result = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc,
matmul_pref_desc, 1, &heuristic_result,
&returned_result));
if (returned_result == 0) {
CHECK_CUBLAS_ERROR(CUBLAS_STATUS_NOT_SUPPORTED);
}

CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
C_data, C_desc, C_data, C_desc, nullptr, nullptr, 0, stream));
C_data, C_desc, C_data, C_desc, &heuristic_result.algo,
workspace_ptr, workspace_size, stream));

cublasLtMatmulDescDestroy(op_desc);
cublasLtMatrixLayoutDestroy(A_desc);
Expand Down
5 changes: 3 additions & 2 deletions src/runtime/contrib/cublas/cublas_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ class CublasJSONRuntime : public JSONRuntimeBase {

auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);

tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr, bias_ptr, out_ptr,
transa, transb, epilogue);
tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
b_ptr, bias_ptr, out_ptr, transa, transb,
entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue);
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/contrib/cublas/cublas_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,25 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
return retval;
}

// Per-thread cuBLASLt state: creates the library handle, a reusable matmul
// preference descriptor, and a device workspace buffer. All three are created
// exactly once per thread (this object lives in thread-local storage) and are
// released by the destructor.
CuBlasLtThreadEntry::CuBlasLtThreadEntry() {
  CHECK_CUBLAS_ERROR(cublasLtCreate(&handle));
  CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&matmul_pref_desc));
  // One-time device allocation; reused by every cublasLtMatmul call on this
  // thread so we never cudaMalloc on the hot path.
  CUDA_CALL(cudaMalloc(&workspace_ptr, workspace_size));
}

// Tear down the per-thread cuBLASLt resources in reverse order of creation.
// Each pointer is nulled after release so a double-destruction is harmless.
CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
  if (workspace_ptr != nullptr) {
    cudaFree(workspace_ptr);
    workspace_ptr = nullptr;
  }
  if (matmul_pref_desc != nullptr) {
    cublasLtMatmulPreferenceDestroy(matmul_pref_desc);
    matmul_pref_desc = nullptr;
  }
  if (handle != nullptr) {
    cublasLtDestroy(handle);
    handle = nullptr;
  }
}

typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;
Expand Down
9 changes: 8 additions & 1 deletion src/runtime/contrib/cublas/cublas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ struct CuBlasThreadEntry {
struct CuBlasLtThreadEntry {
  CuBlasLtThreadEntry();
  ~CuBlasLtThreadEntry();

  /*! \brief Per-thread cuBLASLt library handle, created in the constructor. */
  cublasLtHandle_t handle{nullptr};
  /*! \brief Reusable matmul preference descriptor (carries the workspace-size limit). */
  cublasLtMatmulPreference_t matmul_pref_desc{nullptr};
  /*! \brief Device workspace passed to cublasLtMatmul; allocated once per thread. */
  void* workspace_ptr{nullptr};
  /*!
   * \brief Workspace size in bytes: 32 MB, as this change intends.
   * \note Was 32768 (32 KB), which contradicted the stated 32 MB workspace and
   * is below NVIDIA's recommended cuBLASLt workspace size for recent GPUs.
   */
  static constexpr size_t workspace_size = 32 * 1024 * 1024;

  /*! \brief Get the thread-local singleton instance. */
  static CuBlasLtThreadEntry* ThreadLocal();
};  // CuBlasLtThreadEntry

Expand Down Expand Up @@ -113,8 +118,10 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
}

/*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, const DLTensor* B,
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
void* workspace_ptr, size_t workspace_size,
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT);

} // namespace contrib
Expand Down

0 comments on commit a102768

Please sign in to comment.