diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7fba5355f077..aa2a385683d7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -107,6 +107,7 @@ tvm_option(USE_THRUST "Build with Thrust" OFF)
 tvm_option(USE_CURAND "Build with cuRAND" OFF)
 tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
 tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
+tvm_option(USE_HIPBLAS "Build with ROCM:HIPBLAS" OFF)
 tvm_option(USE_SORT "Build with sort support" ON)
 tvm_option(USE_NNPACK "Build with nnpack support" OFF)
 tvm_option(USE_LIBTORCH "Build with libtorch support" OFF)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index c4637a0c17f7..da9bc3e1c9d3 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -116,6 +116,7 @@ function(add_lib_info src_file)
     TVM_INFO_TVM_DEBUG_WITH_ABI_CHANGE="${TVM_DEBUG_WITH_ABI_CHANGE}"
     TVM_INFO_TVM_LOG_BEFORE_THROW="${TVM_LOG_BEFORE_THROW}"
     TVM_INFO_USE_ROCBLAS="${USE_ROCBLAS}"
+    TVM_INFO_USE_HIPBLAS="${USE_HIPBLAS}"
     TVM_INFO_USE_ROCM="${USE_ROCM}"
     TVM_INFO_USE_RCCL="${USE_RCCL}"
     TVM_INFO_USE_RPC="${USE_RPC}"
diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake
index 02c4c739934a..4d0f76d6871f 100644
--- a/cmake/modules/ROCM.cmake
+++ b/cmake/modules/ROCM.cmake
@@ -53,6 +53,26 @@ if(USE_ROCM)
     list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
   endif(USE_ROCBLAS)
 
+  if(USE_HIPBLAS)
+    message(STATUS "Build with HIPBLAS support")
+    # find_library() in FindROCM.cmake leaves ROCM_HIPBLAS_LIBRARY as <var>-NOTFOUND
+    # when the lookup fails; stop with a clear message instead of linking a bogus entry.
+    if(NOT ROCM_HIPBLAS_LIBRARY)
+      message(FATAL_ERROR "USE_HIPBLAS is ON but the hipblas library was not found")
+    endif()
+    tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRC src/relax/backend/contrib/hipblas/*.cc)
+    list(APPEND COMPILER_SRCS ${HIPBLAS_CONTRIB_SRC})
+    tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRCS src/runtime/contrib/hipblas/*.cc)
+    list(APPEND RUNTIME_SRCS ${HIPBLAS_CONTRIB_SRCS})
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLAS_LIBRARY})
+    # if(<variable>) is already false for *-NOTFOUND values; no string comparison needed.
+    if(ROCM_HIPBLASLT_LIBRARY)
+      list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLASLT_LIBRARY})
+    else()
+      message(WARNING "hipblaslt was not found and will not be linked; the hipBLAS runtime sources call hipblasLt APIs")
+    endif()
+  endif(USE_HIPBLAS)
+
   if(USE_THRUST)
     message(STATUS "Build with rocThrust support")
     # We need to override CXX to hipcc. This is required by rocthrust
diff --git a/cmake/utils/FindROCM.cmake b/cmake/utils/FindROCM.cmake
index 4d895ff89d13..6f54c179ee76 100644
--- a/cmake/utils/FindROCM.cmake
+++ b/cmake/utils/FindROCM.cmake
@@ -55,6 +55,8 @@ macro(find_rocm use_rocm)
     endif()
     find_library(ROCM_MIOPEN_LIBRARY MIOpen ${__rocm_sdk}/lib)
     find_library(ROCM_ROCBLAS_LIBRARY rocblas ${__rocm_sdk}/lib)
+    find_library(ROCM_HIPBLAS_LIBRARY hipblas ${__rocm_sdk}/lib)
+    find_library(ROCM_HIPBLASLT_LIBRARY hipblaslt ${__rocm_sdk}/lib)
     find_library(ROCM_HSA_LIBRARY hsa-runtime64 ${__rocm_sdk}/lib)
 
     if(ROCM_HIPHCC_LIBRARY)
@@ -66,5 +68,7 @@ macro(find_rocm use_rocm)
     message(STATUS "Found ROCM_HIPHCC_LIBRARY=" ${ROCM_HIPHCC_LIBRARY})
     message(STATUS "Found ROCM_MIOPEN_LIBRARY=" ${ROCM_MIOPEN_LIBRARY})
     message(STATUS "Found ROCM_ROCBLAS_LIBRARY=" ${ROCM_ROCBLAS_LIBRARY})
+    message(STATUS "Found ROCM_HIPBLAS_LIBRARY=" ${ROCM_HIPBLAS_LIBRARY})
+    message(STATUS "Found ROCM_HIPBLASLT_LIBRARY=" ${ROCM_HIPBLASLT_LIBRARY})
   endif(ROCM_FOUND)
 endmacro(find_rocm)
diff --git a/python/tvm/contrib/hipblas.py b/python/tvm/contrib/hipblas.py
new file mode 100644
index 000000000000..f1e46a2caab1
--- /dev/null
+++ b/python/tvm/contrib/hipblas.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""External function interface to hipBLAS libraries."""
+import tvm
+from tvm import te
+
+
+def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
+    """Create an extern op that computes the matrix product of lhs and rhs with hipBLAS
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether transpose lhs
+    transb : bool
+        Whether transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    n = lhs.shape[1] if transa else lhs.shape[0]
+    m = rhs.shape[0] if transb else rhs.shape[1]
+    dtype = dtype if dtype is not None else lhs.dtype
+    return te.extern(
+        (n, m),
+        [lhs, rhs],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.hipblas.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        dtype=dtype,
+        name="matmul_hipblas",
+    )
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
+    """Create an extern op that computes the batched matrix product of lhs and rhs with hipBLAS
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether transpose lhs
+    transb : bool
+        Whether transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    dtype = dtype if dtype is not None else lhs.dtype
+    return te.extern(
+        (b, n, m),
+        [lhs, rhs],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.hipblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        dtype=dtype,
+        name="batch_matmul_hipblas",
+    )
diff --git a/python/tvm/relax/backend/contrib/hipblas.py b/python/tvm/relax/backend/contrib/hipblas.py
new file mode 100644
index 000000000000..c0accc1473e1
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/hipblas.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for hipblas backend"""
+import operator
+from functools import reduce
+
+import tvm
+from tvm.relax import transform
+from tvm.relax.transform import PatternCheckContext
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+from ..patterns import make_matmul_pattern
+from ..utils import has_leaking_intermediate_variables
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):  # pylint: disable=unused-argument
+    """Check if dtypes in the given workload are supported by hipblas BYOC."""
+    if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+        # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
+        # return out_dtype != "e5m2_float8"
+        return False
+    return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
+        lhs_dtype == "int8" and rhs_dtype == "int8"
+    )
+
+
+def _check_matmul(context: PatternCheckContext) -> bool:
+    """Decide whether a matched matmul pattern may be offloaded to hipBLAS.
+
+    Requires float16 inputs, a constant reduction axis and, if present, a vector bias.
+    """
+    if has_leaking_intermediate_variables(context):
+        return False
+    lhs = context.annotated_expr["lhs"]
+    rhs = context.annotated_expr["rhs"]
+    matmul_call = context.annotated_expr["root"]
+
+    lhs_dtype = lhs.struct_info.dtype
+    rhs_dtype = rhs.struct_info.dtype
+    out_dtype = matmul_call.struct_info.dtype
+    if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
+        return False
+
+    lhs_shape = lhs.struct_info.shape.values
+    rhs_shape = rhs.struct_info.shape.values
+
+    if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
+        # Reduction axis must be constant
+        return False
+
+    if lhs_dtype == "int8" and rhs_dtype == "int8":
+        # int8 is accepted by _is_supported_dtype but is rejected here.
+        return False
+
+    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+
+    if "bias" in context.annotated_expr:
+        bias = context.annotated_expr["bias"]
+        bias_shape = bias.struct_info.shape.values
+        bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
+        if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1:
+            # hipblas only supports bias vector
+            return False
+
+    # hipblasLt does not seem to support batched GEMM with one of matrices having
+    # one batch (with batch_stride 0). So for batched GEMM, the two batch counts
+    # must be equal. If lhs is batched but rhs is not, we can use the regular GEMM by
+    # flattening all batch axes into the M axis.
+    # A product of several symbolic dims is neither a tir.Var nor convertible with
+    # int(), so every non-constant batch count is accepted the way a single Var is.
+    const_types = (tvm.tir.expr.IntImm, int)
+    if not (isinstance(lhs_batches, const_types) and isinstance(rhs_batches, const_types)):
+        return True
+    return int(lhs_batches) == int(rhs_batches) or int(rhs_batches) == 1
+
+
+register_patterns(
+    [
+        (
+            "hipblas.matmul",
+            *make_matmul_pattern(
+                with_bias=False,
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_bias",
+            *make_matmul_pattern(
+                with_bias=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_bias_relu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.relu",
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_bias_gelu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.gelu",
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_transposed",
+            *make_matmul_pattern(
+                with_bias=False,
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_transposed_bias",
+            *make_matmul_pattern(
+                with_bias=True,
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_transposed_bias_relu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.relu",
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "hipblas.matmul_transposed_bias_gelu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.gelu",
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+    ]
+)
+
+
+def partition_for_hipblas(mod):
+    """
+    Partition the input module into hipblas-supported subgraphs.
+
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The IRModule to be partitioned.
+
+    Returns
+    -------
+    mod: tvm.IRModule
+        The resulting IRModule, containing partitioned subgraphs to be
+        offloaded to the hipblas backend.
+    """
+
+    patterns = get_patterns_with_prefix("hipblas")
+    return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 64eaccb410c8..8227530f7ab7 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -949,6 +949,9 @@ def _multi_gpu_exists():
     parent_features="rocm",
 )
 
+# Mark a test as requiring the hipBLAS library.
+requires_hipblas = Feature("hipblas", "hipBLAS", cmake_flag="USE_HIPBLAS", parent_features="rocm")
+
 # Mark a test as requiring the metal runtime
 requires_metal = Feature(
     "metal",
diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc
new file mode 100644
index 000000000000..7de5c50a614d
--- /dev/null
+++ b/src/relax/backend/contrib/hipblas/codegen.cc
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/hipblas/codegen.cc
+ * \brief Implementation of the HIPBLAS JSON serializer.
+ */ +#include + +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +class HipblasJSONSerializer : public JSONSerializer { + public: + HipblasJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + + NodeEntries inputs_tmp; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end()); + } + + ICHECK(inputs_tmp.size() <= 3); + NodeEntries inputs(inputs_tmp.size()); + + auto arg_idx = backend::ExtractArgIdx(composite_name, fn); + inputs[0] = inputs_tmp[arg_idx["lhs"]->value]; + inputs[1] = inputs_tmp[arg_idx["rhs"]->value]; + if (inputs_tmp.size() == 3) { + inputs[2] = inputs_tmp[arg_idx["bias"]->value]; + } + + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + /*! \brief The bindings to look up composite functions. 
*/ + Map bindings_; +}; + +Array HipblasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; + + for (const auto& func : functions) { + HipblasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.HipblasJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find HIPBLAS runtime module create function."; + auto func_name = GetExtSymbol(func); + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.hipblas").set_body_typed(HipblasCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc new file mode 100644 index 000000000000..c135a2855d89 --- /dev/null +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external hipblas library call. 
+ */ +#include +#include +#include + +#include "../../3rdparty/compiler-rt/builtin_fp16.h" +#include "../cblas/gemm_common.h" +#include "hipblas_utils.h" + +namespace tvm { +namespace contrib { + +using namespace runtime; +inline hipblasOperation_t HIPBLASBooleanToTranspose(bool item) { + return item ? HIPBLAS_OP_T : HIPBLAS_OP_N; +} + +struct HipblasHgemmOp { + typedef hipblasHalf TDatatype; + hipblasHandle_t handle; + explicit HipblasHgemmOp(hipblasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, hipblasHalf alpha, hipblasHalf* A, int lda, + hipblasHalf* B, int ldb, hipblasHalf beta, hipblasHalf* C, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasHgemm(handle, HIPBLASBooleanToTranspose(ta), + HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, lda, B, ldb, + &beta, C, ldc)); + } +}; + +struct HipblasSgemmOp { + typedef float TDatatype; + hipblasHandle_t handle; + explicit HipblasSgemmOp(hipblasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasSgemm(handle, HIPBLASBooleanToTranspose(ta), + HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, lda, B, ldb, + &beta, C, ldc)); + } +}; + +struct HipblasDgemmOp { + typedef double TDatatype; + hipblasHandle_t handle; + explicit HipblasDgemmOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasDgemm(handle, HIPBLASBooleanToTranspose(ta), + HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, lda, B, ldb, + &beta, C, ldc)); + } +}; + +struct HipblasHgemmBatchOp { + typedef hipblasHalf TDatatype; + hipblasHandle_t handle; + explicit HipblasHgemmBatchOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, hipblasHalf alpha, + 
hipblasHalf* A, int a_stride, int lda, hipblasHalf* B, int b_stride, int ldb, + hipblasHalf beta, hipblasHalf* C, int c_stride, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasHgemmStridedBatched( + handle, HIPBLASBooleanToTranspose(ta), HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, + lda, a_stride, B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); + } +}; + +struct HipblasSgemmBatchOp { + typedef float TDatatype; + hipblasHandle_t handle; + explicit HipblasSgemmBatchOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasSgemmStridedBatched( + handle, HIPBLASBooleanToTranspose(ta), HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, + lda, a_stride, B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); + } +}; + +struct HipblasDgemmBatchOp { + typedef double TDatatype; + hipblasHandle_t handle; + explicit HipblasDgemmBatchOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasDgemmStridedBatched( + handle, HIPBLASBooleanToTranspose(ta), HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, + lda, a_stride, B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); + } +}; + +// Check supported mix-precision computation type and return computeType +bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) { + if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { + return TypeMatch(in_dtype, kDLInt, 8); + } else if (TypeMatch(out_dtype, kDLFloat, 32)) { + return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16); + } else { + return false; + } +} + +void 
CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, + hipblasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, + const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, + bool transb, void* workspace_ptr, size_t workspace_size, + hipblasLtEpilogue_t epilogue) { + ICHECK(TypeEqual(A->dtype, B->dtype)); + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? !transb : transb; + + auto compute_type = HIPBLAS_COMPUTE_32F; + auto scale_type = HIP_R_32F; + hipDataType ab_type = HIP_R_32F; + hipDataType c_type = HIP_R_32F; + float one_fp32 = 1.0; + float zero_fp32 = 0.0; + int32_t one_i32 = 1; + int32_t zero_i32 = 0; + void* alpha = &one_fp32; + void* beta = &zero_fp32; + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + ab_type = HIP_R_16F; + } else if (TypeMatch(A->dtype, kDLInt, 8)) { + ab_type = HIP_R_8I; + } + + if (TypeMatch(C->dtype, kDLFloat, 16)) { + c_type = HIP_R_16F; + } else if (TypeMatch(C->dtype, kDLInt, 32)) { + c_type = HIP_R_32I; + compute_type = HIPBLAS_COMPUTE_32I; + scale_type = HIP_R_32I; + alpha = &one_i32; + beta = &zero_i32; + } + + hipblasLtMatmulDesc_t op_desc; + hipblasOperation_t op_transa = HIPBLASBooleanToTranspose(transa); + hipblasOperation_t op_transb = HIPBLASBooleanToTranspose(transb); + + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_TRANSA, + &op_transb, sizeof(op_transb))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_TRANSB, + &op_transa, sizeof(op_transa))); + + if (bias != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias->data, sizeof(float*))); + } + + if (epilogue != HIPBLASLT_EPILOGUE_DEFAULT) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, 
HIPBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, sizeof(epilogue))); + } + + int batch_offset_A = A->ndim - 2; + int batch_offset_B = B->ndim - 2; + + int M = ColumnCount(B, transb, batch_offset_B); + int N = RowCount(A, transa, batch_offset_A); + int K = ColumnCount(A, transa, batch_offset_A); + bool use_batched_gemm = A->ndim > 2 || B->ndim > 2; + + // If A is batched but B is not, flatten all non-reduction axes of A to use the regular GEMM. + // This trick is only applicable if batch axes and the other spatial axis (M or N) are + // adjacent in both the input and the output matrix. In particular, if A is of shape (M, K) + // and B matrix is of shape (Batch, N, K) with transb = true, the output shape + // is (Batch, M, N). Since the Batch and the N axes are not adjacent in the output, we cannot + // use the regular GEMM if only B is batched. + if (A->ndim > 2 && B->ndim == 2 && transa == false) { + N = 1; + for (int i = 0; i < A->ndim - 1; ++i) { + N *= A->shape[i]; + } + use_batched_gemm = false; + } + + int lda = transb ? K : M; + int ldb = transa ? N : K; + int ldc = M; + + hipblasLtMatrixLayout_t A_desc, B_desc, C_desc; + CHECK_HIPBLAS_ERROR( + hipblasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? 
N : K, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc)); + + if (use_batched_gemm) { + auto get_batch_count = [](int64_t* shape, int batch_offset) { + int64_t count = 1; + for (int i = 0; i < batch_offset; ++i) { + count *= shape[i]; + } + return count; + }; + auto set_batch = [](hipblasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_desc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count))); + CHECK_HIPBLAS_ERROR( + hipblasLtMatrixLayoutSetAttribute(mat_desc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, sizeof(batch_stride))); + }; + + int batch_count_A = get_batch_count(A->shape, batch_offset_A); + int batch_count_B = get_batch_count(B->shape, batch_offset_B); + int batch_count_C = get_batch_count(C->shape, C->ndim - 2); + int64_t batch_stride_A = M * K; + int64_t batch_stride_B = K * N; + int64_t batch_stride_C = M * N; + + // hipBLASLt does not seem to support batched GEMM with one of matrices having + // one batch (with batch_stride 0). 
+ ICHECK_EQ(batch_count_A, batch_count_B); + + set_batch(A_desc, batch_count_A, batch_stride_A); + set_batch(B_desc, batch_count_B, batch_stride_B); + set_batch(C_desc, batch_count_C, batch_stride_C); + } + + auto A_data = static_cast(A->data) + A->byte_offset; + auto B_data = static_cast(B->data) + B->byte_offset; + auto C_data = static_cast(C->data) + C->byte_offset; + + hipblasLtMatmulPreferenceSetAttribute(matmul_pref_desc, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(size_t)); + + hipblasLtMatmulHeuristicResult_t heuristic_result = {}; + int returned_result = 0; + CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc, + matmul_pref_desc, 1, &heuristic_result, + &returned_result)); + if (returned_result == 0) { + CHECK_HIPBLAS_ERROR(HIPBLAS_STATUS_NOT_SUPPORTED); + } + + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta, + C_data, C_desc, C_data, C_desc, &heuristic_result.algo, + workspace_ptr, workspace_size, stream)); + + hipblasLtMatmulDescDestroy(op_desc); + hipblasLtMatrixLayoutDestroy(A_desc); + hipblasLtMatrixLayoutDestroy(B_desc); + hipblasLtMatrixLayoutDestroy(C_desc); +} + +inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + + ICHECK_EQ(ElementStride(A), 1); + ICHECK_EQ(ElementStride(B), 1); + ICHECK_EQ(ElementStride(C), 1); + + ICHECK(TypeEqual(A->dtype, B->dtype)); + + // C can never be transposed. + ICHECK(!IsInPlaceTransposed(C)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? 
!transb : transb; + + ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + + hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); + hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype); + hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT; + void *alpha_ptr = nullptr, *beta_ptr = nullptr; + auto alpha_int = static_cast(alpha); + auto beta_int = static_cast(beta); + auto alpha_float = static_cast(alpha); + auto beta_float = static_cast(beta); + if (C->dtype.code == kDLInt) { + alpha_ptr = &alpha_int; + beta_ptr = &beta_int; + } else if (C->dtype.code == kDLFloat) { + alpha_ptr = &alpha_float; + beta_ptr = &beta_float; + } + + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + CHECK_HIPBLAS_ERROR( + hipblasGemmEx(hdl, HIPBLASBooleanToTranspose(transb), HIPBLASBooleanToTranspose(transa), + ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), alpha_ptr, + B_data, hip_in_type, ColumnStride(B), A_data, hip_in_type, ColumnStride(A), + beta_ptr, C_data, hip_out_type, ColumnStride(C), hip_out_type, algo)); +} + +inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + ICHECK_EQ(A->ndim, 3); + ICHECK_EQ(B->ndim, 3); + ICHECK_EQ(C->ndim, 3); + + int batch_size = BatchCount3D(C); + ICHECK_EQ(ElementStride3D(A), 1); + ICHECK_EQ(ElementStride3D(B), 1); + 
ICHECK_EQ(ElementStride3D(C), 1); + + ICHECK(TypeEqual(A->dtype, B->dtype)); + + // C can never be transposed. + ICHECK(!IsInPlaceTransposed3D(C)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed3D(A) ? !transa : transa; + transb = IsInPlaceTransposed3D(B) ? !transb : transb; + + ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; + ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } + + hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); + hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype); + hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT; + void *alpha_ptr = nullptr, *beta_ptr = nullptr; + auto alpha_int = static_cast(alpha); + auto beta_int = static_cast(beta); + auto alpha_float = static_cast(alpha); + auto beta_float = static_cast(beta); + if (C->dtype.code == kDLInt) { + alpha_ptr = &alpha_int; + beta_ptr = &beta_int; + } else if (C->dtype.code == kDLFloat) { + alpha_ptr = &alpha_float; + beta_ptr = &beta_float; + } + + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = 
reinterpret_cast(static_cast(C->data) + C->byte_offset); + CHECK_HIPBLAS_ERROR(hipblasGemmStridedBatchedEx( + hdl, HIPBLASBooleanToTranspose(transb), HIPBLASBooleanToTranspose(transa), + ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, + hip_in_type, ColumnStride3D(B), B_stride, A_data, hip_in_type, ColumnStride3D(A), A_stride, + beta_ptr, C_data, hip_out_type, ColumnStride3D(C), C_stride, batch_size, hip_out_type, algo)); +} + +// matrix multiplication for row major +TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; + + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + CallGemm(args, ret, HipblasHgemmOp(entry_ptr->handle)); + } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallGemm(args, ret, HipblasSgemmOp(entry_ptr->handle)); + } else { + CallGemm(args, ret, HipblasDgemmOp(entry_ptr->handle)); + } + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; + + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + CallBatchGemm(args, ret, HipblasHgemmBatchOp(entry_ptr->handle)); + } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, HipblasSgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemm(args, ret, HipblasDgemmBatchOp(entry_ptr->handle)); + } + } else { + CallBatchGemmEx(args, ret, 
entry_ptr->handle);
+      }
+    });
+
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
new file mode 100644
index 000000000000..a6e7949e4559
--- /dev/null
+++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/hipblas/hipblas_json_runtime.cc
+ * \brief A simple JSON runtime for HIPBLAS.
+ */
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "hipblas_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+/*!
+ * \brief JSON runtime module that executes every "kernel" node of its graph through
+ *  tvm::contrib::CallHipblasLt, reading inputs and outputs directly from the call arguments.
+ */
+class HipblasJSONRuntime : public JSONRuntimeBase {
+ public:
+  HipblasJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
+                     const Array const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  // Nothing to set up at init time; constants are ignored here.
+  void Init(const Array& consts) override {}
+
+  /*!
+   * \brief Return the packed function for `name`.
+   *  When `name` equals symbol_name_, the returned function checks initialized_ and then calls
+   *  Run(args); any other name is delegated to JSONRuntimeBase::GetFunction.
+   */
+  PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override {
+    // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe.
Since HipblasJSONRuntime
+    // can be used by multiple GPUs running on different threads, we avoid using that function
+    // and directly call hipBLAS on the inputs from TVMArgs.
+    if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        ICHECK(this->initialized_) << "The module has not been initialized";
+        this->Run(args);
+      });
+    } else {
+      return JSONRuntimeBase::GetFunction(name, sptr_to_self);
+    }
+  }
+
+  const char* type_key() const override { return "hipblas_json"; }  // May be overridden
+
+  /*!
+   * \brief Execute the graph on the tensors passed in `args`.
+   *  The first input_var_eid_.size() arguments are treated as graph inputs and the remaining
+   *  ones as graph outputs. Each argument must be an NDArray or a DLTensor handle.
+   */
+  void Run(TVMArgs args) {
+    // Per-thread hipBLASLt handle, matmul preference descriptor and workspace.
+    auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal();
+
+    // The stream is obtained through the registry rather than a direct ROCm runtime call.
+    auto func = tvm::runtime::Registry::Get("runtime.get_rocm_stream");
+    ICHECK(func != nullptr);
+    hipStream_t stream = static_cast((*func)().operator void*());
+
+    // Table from graph entry id to the tensor supplied for it in this call.
+    std::vector dl_tensors(NumEntries());
+
+    for (size_t i = 0; i < static_cast(args.size()); i++) {
+      auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
+                                           : EntryID(outputs_[i - input_var_eid_.size()]);
+      ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
+          << "Expect NDArray or DLTensor as inputs";
+
+      const DLTensor* arg;
+      if (args[i].IsObjectRef()) {
+        NDArray arr = args[i];
+        arg = arr.operator->();
+      } else {
+        arg = args[i].operator DLTensor*();
+      }
+
+      dl_tensors[eid] = arg;
+    }
+
+    // Look up the tensor bound to the idx-th input of `node`, with bounds checks.
+    auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
+      ICHECK_LT(idx, node.GetInputs().size());
+      auto eid = EntryID(node.GetInputs()[idx]);
+      ICHECK(eid < dl_tensors.size());
+      return dl_tensors[eid];
+    };
+
+    // Returns (input 0, input 1, bias); bias is input 2 when has_bias is set, else nullptr.
+    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
+      const DLTensor* bias = nullptr;
+      if (has_bias) {
+        bias = get_input(node, 2);
+      }
+      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+    };
+
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      const auto& node = nodes_[i];
+      if (node.GetOpType() == "kernel") {
+        auto op_name = node.GetOpName();
+        // Every kernel node writes to the first graph output.
+        uint32_t output_eid = EntryID(outputs_[0]);
+        auto out_ptr = dl_tensors[output_eid];
+        bool transa = false;
+        bool transb = false;
+        hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
+
+        // The op name encodes the fused pattern; only the second operand can be transposed.
+        if (op_name.find("transposed") != std::string::npos) {
+          transb = true;
+        }
+
+        // "relu" and "gelu" take precedence over a plain "bias" match.
+        // NOTE(review): relu/gelu always select the *_BIAS epilogue and therefore always read
+        // input 2 as bias - presumably every such pattern carries a bias; confirm against the
+        // partitioning patterns.
+        if (op_name.find("relu") != std::string::npos) {
+          epilogue = HIPBLASLT_EPILOGUE_RELU_BIAS;
+        } else if (op_name.find("gelu") != std::string::npos) {
+          epilogue = HIPBLASLT_EPILOGUE_GELU_BIAS;
+        } else if (op_name.find("bias") != std::string::npos) {
+          epilogue = HIPBLASLT_EPILOGUE_BIAS;
+        }
+
+        auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != HIPBLASLT_EPILOGUE_DEFAULT);
+
+        tvm::contrib::CallHipblasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
+                                    b_ptr, bias_ptr, out_ptr, transa, transb,
+                                    entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue);
+      }
+    }
+  }
+
+  // The argument-less overload must never be reached; GetFunction routes to Run(TVMArgs).
+  void Run() override { LOG(FATAL) << "Unreachable"; }
+};
+
+/*! \brief Create a HipblasJSONRuntime module from a symbol name, graph JSON and constant names. */
+runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json,
+                                         const Array& const_names) {
+  auto n = make_object(symbol_name, graph_json, const_names);
+  return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate").set_body_typed(HipblasJSONRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc
new file mode 100644
index 000000000000..02d91646518c
--- /dev/null
+++ b/src/runtime/contrib/hipblas/hipblas_utils.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/hipblas/hipblas_utils.cc
+ * \brief Thread-local hipBLAS and hipBLASLt handle management.
+ */
+#include "hipblas_utils.h"
+
+#include
+#include
+
+#include "../../rocm/rocm_common.h"
+
+namespace tvm {
+namespace contrib {
+
+// Creates the hipBLAS handle; CHECK_HIPBLAS_ERROR fails if hipblasCreate does not succeed.
+HipBlasThreadEntry::HipBlasThreadEntry() { CHECK_HIPBLAS_ERROR(hipblasCreate(&handle)); }
+
+// Destroys the handle if one exists. The status returned by hipblasDestroy is not checked.
+HipBlasThreadEntry::~HipBlasThreadEntry() {
+  if (handle) {
+    hipblasDestroy(handle);
+    handle = nullptr;
+  }
+}
+
+typedef dmlc::ThreadLocalStore HipBlasThreadStore;
+
+// Returns the calling thread's entry. On every call the handle is re-bound to the stream
+// currently held by the thread's ROCMThreadEntry.
+HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal() {
+  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
+  HipBlasThreadEntry* retval = HipBlasThreadStore::Get();
+  CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast(stream)));
+  return retval;
+}
+
+// Creates the hipBLASLt handle and the matmul preference descriptor, then allocates
+// workspace_size bytes with hipMalloc.
+// NOTE(review): if a later step fails, the resources created by the earlier steps are not
+// released here - presumably acceptable because a failed check aborts the operation; confirm.
+HipBlasLtThreadEntry::HipBlasLtThreadEntry() {
+  CHECK_HIPBLAS_ERROR(hipblasLtCreate(&handle));
+  CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&matmul_pref_desc));
+  ROCM_CALL(hipMalloc(&workspace_ptr, workspace_size));
+}
+
+// Releases the handle, the preference descriptor and the workspace, each only if it is set.
+// Return codes of the destroy/free calls are ignored.
+HipBlasLtThreadEntry::~HipBlasLtThreadEntry() {
+  if (handle) {
+    hipblasLtDestroy(handle);
+    handle = nullptr;
+  }
+  if (matmul_pref_desc) {
+    hipblasLtMatmulPreferenceDestroy(matmul_pref_desc);
+    matmul_pref_desc = nullptr;
+  }
+  if (workspace_ptr != nullptr) {
+    hipFree(workspace_ptr);
+    workspace_ptr = nullptr;
+  }
+}
+
+typedef dmlc::ThreadLocalStore HipBlasLtThreadStore;
+
+// Returns the calling thread's hipBLASLt entry. Unlike HipBlasThreadEntry::ThreadLocal(), no
+// stream is attached here.
+HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal() { return HipBlasLtThreadStore::Get(); }
+
+}  // namespace contrib
+
+}  // namespace tvm
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h
new file mode 100644
index 000000000000..66d7afafbd64
--- /dev/null
+++ b/src/runtime/contrib/hipblas/hipblas_utils.h
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Use external hipblas utils function
+ */
+#ifndef TVM_RUNTIME_CONTRIB_HIPBLAS_HIPBLAS_UTILS_H_
+#define TVM_RUNTIME_CONTRIB_HIPBLAS_HIPBLAS_UTILS_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace contrib {
+/*!
+ * \brief Map a hipBLAS status code to the name of its enumerator.
+ * \param error The status value, already converted to int.
+ * \return The enumerator name, or "Unrecognized error" for values not listed below.
+ */
+inline const char* GetHipblasErrorString(int error) {
+  switch (error) {
+    case HIPBLAS_STATUS_NOT_INITIALIZED:
+      return "HIPBLAS_STATUS_NOT_INITIALIZED";
+    case HIPBLAS_STATUS_ALLOC_FAILED:
+      return "HIPBLAS_STATUS_ALLOC_FAILED";
+    case HIPBLAS_STATUS_INVALID_VALUE:
+      return "HIPBLAS_STATUS_INVALID_VALUE";
+    case HIPBLAS_STATUS_ARCH_MISMATCH:
+      return "HIPBLAS_STATUS_ARCH_MISMATCH";
+    case HIPBLAS_STATUS_MAPPING_ERROR:
+      return "HIPBLAS_STATUS_MAPPING_ERROR";
+    case HIPBLAS_STATUS_EXECUTION_FAILED:
+      return "HIPBLAS_STATUS_EXECUTION_FAILED";
+    case HIPBLAS_STATUS_INTERNAL_ERROR:
+      return "HIPBLAS_STATUS_INTERNAL_ERROR";
+    case HIPBLAS_STATUS_NOT_SUPPORTED:
+      return "HIPBLAS_STATUS_NOT_SUPPORTED";
+    // The remaining failure codes of hipblasStatus_t, so they are reported by name as well.
+    case HIPBLAS_STATUS_HANDLE_IS_NULLPTR:
+      return "HIPBLAS_STATUS_HANDLE_IS_NULLPTR";
+    case HIPBLAS_STATUS_INVALID_ENUM:
+      return "HIPBLAS_STATUS_INVALID_ENUM";
+    case HIPBLAS_STATUS_UNKNOWN:
+      return "HIPBLAS_STATUS_UNKNOWN";
+  }
+  return "Unrecognized error";
+}
+
+// Evaluates `fn` once and fails an ICHECK, with the decoded status name, unless the result
+// equals HIPBLAS_STATUS_SUCCESS.
+#ifndef CHECK_HIPBLAS_ERROR
+#define CHECK_HIPBLAS_ERROR(fn)                                                            \
+  do {                                                                                     \
+    int error = static_cast(fn);                                                           \
+    ICHECK_EQ(error, HIPBLAS_STATUS_SUCCESS) << "HIPBLAS: " << GetHipblasErrorString(error); \
+  } while (0)  // ; intentionally left off.
+#endif  // CHECK_HIPBLAS_ERROR
+
+/*! \brief Per-thread owner of a hipBLAS handle. */
+struct HipBlasThreadEntry {
+  HipBlasThreadEntry();
+  ~HipBlasThreadEntry();
+  hipblasHandle_t handle{nullptr};
+  static HipBlasThreadEntry* ThreadLocal();
+};  // HipBlasThreadEntry
+
+/*! \brief Per-thread owner of a hipBLASLt handle, matmul preference and workspace buffer. */
+struct HipBlasLtThreadEntry {
+  HipBlasLtThreadEntry();
+  ~HipBlasLtThreadEntry();
+
+  hipblasLtHandle_t handle{nullptr};
+  hipblasLtMatmulPreference_t matmul_pref_desc{nullptr};
+  void* workspace_ptr{nullptr};
+  // 32 MiB workspace. The figure is the one NVIDIA suggests for cuBLAS:
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace.
+  // NOTE(review): this cites cuBLAS guidance, not hipBLASLt guidance - confirm the size is
+  // also appropriate for hipBLASLt.
+  static constexpr const size_t workspace_size = 33554432;
+
+  static HipBlasLtThreadEntry* ThreadLocal();
+};  // HipBlasLtThreadEntry
+
+/*!
+ * \brief Convert a DLPack dtype to the corresponding hipDataType.
+ *  Handles int8/int32, uint8/uint32 and float16/32/64; any other dtype hits LOG(FATAL).
+ */
+inline hipDataType GetHipDataType(DLDataType type) {
+  if (type.code == kDLInt) {
+    switch (type.bits) {
+      case 8:
+        return HIP_R_8I;
+      case 32:
+        return HIP_R_32I;
+    }
+  } else if (type.code == kDLUInt) {
+    switch (type.bits) {
+      case 8:
+        return HIP_R_8U;
+      case 32:
+        return HIP_R_32U;
+    }
+  } else if (type.code == kDLFloat) {
+    switch (type.bits) {
+      case 16:
+        return HIP_R_16F;
+      case 32:
+        return HIP_R_32F;
+      case 64:
+        return HIP_R_64F;
+    }
+  }
+  LOG(FATAL) << "Unsupported hip type";
+}
+
+/*!
+ * \brief Convert a DLPack dtype to the corresponding hipblasDatatype_t.
+ *  Covers the same dtypes as GetHipDataType; any other dtype hits LOG(FATAL).
+ */
+inline hipblasDatatype_t GetHipBlasDataType(DLDataType type) {
+  if (type.code == kDLInt) {
+    switch (type.bits) {
+      case 8:
+        return HIPBLAS_R_8I;
+      case 32:
+        return HIPBLAS_R_32I;
+    }
+  } else if (type.code == kDLUInt) {
+    switch (type.bits) {
+      case 8:
+        return HIPBLAS_R_8U;
+      case 32:
+        return HIPBLAS_R_32U;
+    }
+  } else if (type.code == kDLFloat) {
+    switch (type.bits) {
+      case 16:
+        return HIPBLAS_R_16F;
+      case 32:
+        return HIPBLAS_R_32F;
+      case 64:
+        return HIPBLAS_R_64F;
+    }
+  }
+  LOG(FATAL) << "Unsupported hip type";
+}
+
+/*! \brief Execute matrix multiply followed by the specified epilogue, using hipBLASLt. */
+void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream,
+                   hipblasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A,
+                   const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa,
+                   bool transb, void* workspace_ptr, size_t workspace_size,
+                   hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT);
+
+}  // namespace contrib
+
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_HIPBLAS_HIPBLAS_UTILS_H_
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 561e495a357d..984a2f3323ad 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -360,6 +360,7 @@ TVM_DLL Map GetLibInfo() {
       {"TVM_DEBUG_WITH_ABI_CHANGE", TVM_INFO_TVM_DEBUG_WITH_ABI_CHANGE},
       {"TVM_LOG_BEFORE_THROW", TVM_INFO_TVM_LOG_BEFORE_THROW},
       {"USE_ROCBLAS", TVM_INFO_USE_ROCBLAS},
+      {"USE_HIPBLAS", TVM_INFO_USE_HIPBLAS},
       {"USE_ROCM", TVM_INFO_USE_ROCM},
       {"USE_RCCL", TVM_INFO_USE_RCCL},
       {"USE_RPC", TVM_INFO_USE_RPC},
diff --git a/tests/python/contrib/test_hipblas.py b/tests/python/contrib/test_hipblas.py
new file mode 100644
index 000000000000..63a7553704bf
--- /dev/null
+++ b/tests/python/contrib/test_hipblas.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.contrib import hipblas
+
+
+def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
+    """Run a (1024, 128) x (128, 236) hipblas.matmul on ROCm and compare it with np.dot.
+
+    Inputs are drawn uniformly from [0, 128) and cast to ``in_dtype``; the reference is
+    computed in ``out_dtype``. If the extern function is not registered, a message is
+    printed and nothing is checked.
+    """
+    n = 1024
+    l = 128
+    m = 236
+    A = te.placeholder((n, l), name="A", dtype=in_dtype)
+    B = te.placeholder((l, m), name="B", dtype=in_dtype)
+    C = hipblas.matmul(A, B, dtype=out_dtype)
+    s = te.create_schedule(C.op)
+
+    def verify(target="rocm"):
+        # Availability is probed at run time; a missing function makes the test pass silently.
+        if not tvm.get_global_func("tvm.contrib.hipblas.matmul", True):
+            print("skip because extern function is not available")
+            return
+        dev = tvm.rocm(0)
+        f = tvm.build(s, [A, B, C], target)
+        a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), dev)
+        b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), dev)
+        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev)
+        f(a, b, c)
+        tvm.testing.assert_allclose(
+            c.numpy(), np.dot(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)), rtol=rtol
+        )
+
+    verify()
+
+
+def roundoff(v, d):
+    """Return ``floor((v + d - 1) / d) * d`` as an int.
+
+    For integer ``v`` and positive integer ``d`` this is ``v`` rounded up to a multiple of ``d``.
+    NOTE(review): no caller is visible in this file - confirm it is still needed.
+    """
+    return int(np.floor((v + d - 1) / d) * d)
+
+
+def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5):
+    """Run hipblas.batch_matmul on ROCm for the given shapes and compare it with np.matmul.
+
+    Integer input dtypes are filled from uniform [1, 10); other dtypes from uniform [0, 1).
+    The reference is computed after casting both inputs to the output dtype.
+    """
+    A = te.placeholder(Ashape, name="A", dtype=in_dtype)
+    B = te.placeholder(Bshape, name="B", dtype=in_dtype)
+    C = hipblas.batch_matmul(A, B, dtype=out_dtype)
+    s = te.create_schedule(C.op)
+
+    dev = tvm.rocm(0)
+    f = tvm.build(s, [A, B, C], "rocm")
+
+    if "int" in in_dtype:
+        a = tvm.nd.array(np.random.uniform(1, 10, size=Ashape).astype(in_dtype), dev)
+        b = tvm.nd.array(np.random.uniform(1, 10, size=Bshape).astype(in_dtype), dev)
+    else:
+        a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev)
+        b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev)
+
+    c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev)
+    f(a, b, c)
+    tvm.testing.assert_allclose(
+        c.numpy(),
+        np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype),
+        rtol=rtol,
+    )
+
+
+@tvm.testing.requires_rocm
+def test_matmul_add():
+    """Exercise matmul for float32, float16 -> float32, float16 and int8 -> int32."""
+    verify_matmul_add("float", "float", rtol=1e-3)
+    verify_matmul_add("float16", "float")
+    verify_matmul_add("float16", "float16", rtol=1e-2)
+    verify_matmul_add("int8", "int32")
+
+
+@tvm.testing.requires_rocm
+def test_batch_matmul():
+    """Exercise batch_matmul, including cases where the second operand has batch size 1."""
+    if not tvm.get_global_func("tvm.contrib.hipblas.batch_matmul", True):
+        print("skip because extern function is not available")
+        return
+
+    verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float", "float")
+    verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float", "float")
+    verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float")
+    verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float")
+    verify_batch_matmul(
+        (16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2
+    )
+    verify_batch_matmul(
+        (16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2
+    )
+
+    verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32")
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py
new file mode 100644
index 000000000000..f43b83802b81
--- /dev/null
+++ b/tests/python/relax/test_codegen_hipblas.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relax
+from tvm.relax.backend.contrib.hipblas import partition_for_hipblas
+from tvm.relax.testing import get_relax_matmul_module
+from tvm.script import relax as R
+
+# Optional dependency; falls back to None when it is not installed.
+# NOTE(review): ml_dtypes is not referenced anywhere else in the visible code.
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
+
+@pytest.fixture(autouse=True)
+def reset_seed():
+    """Seed NumPy's global RNG with 0 before every test in this module."""
+    np.random.seed(0)
+
+
+# Apply the hipBLAS requirement marks to every test in this module.
+pytestmark = tvm.testing.requires_hipblas.marks()
+
+
+def build_and_run(mod, inputs_np, target, legalize=False):
+    """Build ``mod`` for ``target``, run its "main" on device 0 and return the result as NumPy.
+
+    ``legalize`` is forwarded as the "relax.transform.apply_legalize_ops" pass-context option.
+    """
+    dev = tvm.device(target, 0)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, dev)
+    f = vm["main"]
+    inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+    return f(*inputs).numpy()
+
+
+def get_result_with_relax_cublas_offload(mod, np_inputs):
+    """Partition ``mod`` for hipBLAS, run codegen, then build and run it on "rocm".
+
+    NOTE(review): the name says "cublas" although the body partitions for hipBLAS - looks like
+    a leftover from the cuBLAS test; rename together with its caller if confirmed.
+    """
+    mod = partition_for_hipblas(mod)
+    mod = relax.transform.RunCodegen()(mod)
+
+    return build_and_run(mod, np_inputs, "rocm")
+
+
+def _to_concrete_shape(symbolic_shape, var_table):
+    """Replace every tir Var in ``symbolic_shape`` with an integer and return a tuple.
+
+    A Var seen for the first time gets a random value from ``np.random.randint(10, 50)``, which
+    is stored in ``var_table`` so the same Var maps to the same value on later calls.
+    """
+    result = []
+    for dim in symbolic_shape:
+        if not isinstance(dim, tvm.tir.expr.Var):
+            result.append(dim)
+            continue
+
+        if dim not in var_table:
+            var_table[dim] = np.random.randint(10, 50)
+        result.append(var_table[dim])
+
+    return tuple(result)
+
+
+# Symbolic dimensions shared by the parametrized shapes below.
+_vars = {
+    "a": tvm.tir.expr.Var("a", "int64"),
+    "b": tvm.tir.expr.Var("b", "int64"),
+}
+
+
+# Epilogue name -> (with_bias, activation); activation is None when no activation is applied.
+_epilogue_table = {
+    "none": (False, None),
+    "bias": (True, None),
+    "relu": (True, R.nn.relu),
+    "gelu": (True, R.nn.gelu),
+}
+
+
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, epilogue",
+    [
+        # Regular
+        ((8, 8), (8, 8), False, "none"),
+        ((_vars["a"], 6), (6, 16), False, "bias"),
+        # Transposed
+        ((4, 16), (16, 128), True, "relu"),
+        ((35, 8), (8, 8), True, "gelu"),
+        # 3D x 3D
+        ((6, 32, 8), (6, 8, 10), False, "bias"),
+        ((6, 32, 8), (6, 8, 10), True, "none"),
+        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
+        # ND x ND
+        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
+        # ND x 2D
+        ((5, 3, 32, 8), (8, 10), False, "none"),
+    ],
+)
+@pytest.mark.parametrize(
+    "in_dtype, out_dtype",
+    [
+        ("float16", "float16"),
+        ("float32", "float32"),
+    ],
+)
+def test_matmul_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    epilogue,
+    in_dtype,
+    out_dtype,
+):
+    """Compare the hipBLAS-offloaded matmul with the same module built for "llvm".
+
+    Symbolic dimensions are replaced with random concrete sizes for the input data, while the
+    module itself is created from the (possibly symbolic) shapes.
+    """
+    with_bias, activation = _epilogue_table[epilogue]
+    var_table = {}
+    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
+    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
+    x = np.random.randn(*concrete_x_shape).astype(in_dtype)
+    y = np.random.randn(*concrete_y_shape).astype(in_dtype)
+
+    # For the transposed case, y is passed with its last two axes swapped, and the shape handed
+    # to the module builder is swapped the same way.
+    if transpose_y:
+        y = np.swapaxes(y, -2, -1)
+        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
+
+    # The bias is 1-D and sized from the last axis of the un-transposed y shape.
+    if with_bias:
+        bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
+        args = (x, y, bias)
+    else:
+        bias = None
+        args = (x, y)
+
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=bias.shape if with_bias else None,
+        transposed_y=transpose_y,
+        activation=activation,
+    )
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_hipblas_partition_matmul_without_bias():
+    """A 2-D bias must not be fused: after partitioning, main's first block keeps 2 bindings."""
+    # hipBLAS does not handle 2D bias (residual input)
+    mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
+    mod = partition_for_hipblas(mod)
+
+    # R.add is still in the main function
+    assert len(mod["main"].body.blocks[0].bindings) == 2
+
+
+if __name__ == "__main__":
+    tvm.testing.main()