[ROCm] hipBLAS integration (#17290)
This commit integrates hipBLAS into TVM. The minimum ROCm version
requirement is 6.0.

Co-authored-by: Lesheng Jin <[email protected]>
MasterJH5574 and LeshengJin authored Aug 22, 2024
1 parent 0f037a6 commit 8db545d
Showing 15 changed files with 1,514 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -107,6 +107,7 @@ tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_CURAND "Build with cuRAND" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_HIPBLAS "Build with ROCM:HIPBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" ON)
tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_LIBTORCH "Build with libtorch support" OFF)
Expand Down
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -116,6 +116,7 @@ function(add_lib_info src_file)
    TVM_INFO_TVM_DEBUG_WITH_ABI_CHANGE="${TVM_DEBUG_WITH_ABI_CHANGE}"
    TVM_INFO_TVM_LOG_BEFORE_THROW="${TVM_LOG_BEFORE_THROW}"
    TVM_INFO_USE_ROCBLAS="${USE_ROCBLAS}"
    TVM_INFO_USE_HIPBLAS="${USE_HIPBLAS}"
    TVM_INFO_USE_ROCM="${USE_ROCM}"
    TVM_INFO_USE_RCCL="${USE_RCCL}"
    TVM_INFO_USE_RPC="${USE_RPC}"
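The new `TVM_INFO_USE_HIPBLAS` entry is surfaced through `tvm.support.libinfo()`, so whether a given TVM build has hipBLAS enabled can be checked at runtime. A minimal sketch:

```python
import tvm

# Prints "ON" for a build configured with -DUSE_HIPBLAS=ON, "OFF" otherwise.
print(tvm.support.libinfo().get("USE_HIPBLAS"))
```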
12 changes: 12 additions & 0 deletions cmake/modules/ROCM.cmake
@@ -53,6 +53,18 @@ if(USE_ROCM)
    list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
  endif(USE_ROCBLAS)

  if(USE_HIPBLAS)
    message(STATUS "Build with HIPBLAS support")
    tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRC src/relax/backend/contrib/hipblas/*.cc)
    list(APPEND COMPILER_SRCS ${HIPBLAS_CONTRIB_SRC})
    tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRCS src/runtime/contrib/hipblas/*.cc)
    list(APPEND RUNTIME_SRCS ${HIPBLAS_CONTRIB_SRCS})
    list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLAS_LIBRARY})
    if(NOT ROCM_HIPBLASLT_LIBRARY STREQUAL "ROCM_HIPBLASLT_LIBRARY-NOTFOUND")
      list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLASLT_LIBRARY})
    endif()
  endif(USE_HIPBLAS)

  if(USE_THRUST)
    message(STATUS "Build with rocThrust support")
    # We need to override CXX to hipcc. This is required by rocthrust
4 changes: 4 additions & 0 deletions cmake/utils/FindROCM.cmake
@@ -55,6 +55,8 @@ macro(find_rocm use_rocm)
  endif()
  find_library(ROCM_MIOPEN_LIBRARY MIOpen ${__rocm_sdk}/lib)
  find_library(ROCM_ROCBLAS_LIBRARY rocblas ${__rocm_sdk}/lib)
  find_library(ROCM_HIPBLAS_LIBRARY hipblas ${__rocm_sdk}/lib)
  find_library(ROCM_HIPBLASLT_LIBRARY hipblaslt ${__rocm_sdk}/lib)
  find_library(ROCM_HSA_LIBRARY hsa-runtime64 ${__rocm_sdk}/lib)

  if(ROCM_HIPHCC_LIBRARY)
@@ -66,5 +68,7 @@ macro(find_rocm use_rocm)
    message(STATUS "Found ROCM_HIPHCC_LIBRARY=" ${ROCM_HIPHCC_LIBRARY})
    message(STATUS "Found ROCM_MIOPEN_LIBRARY=" ${ROCM_MIOPEN_LIBRARY})
    message(STATUS "Found ROCM_ROCBLAS_LIBRARY=" ${ROCM_ROCBLAS_LIBRARY})
    message(STATUS "Found ROCM_HIPBLAS_LIBRARY=" ${ROCM_HIPBLAS_LIBRARY})
    message(STATUS "Found ROCM_HIPBLASLT_LIBRARY=" ${ROCM_HIPBLASLT_LIBRARY})
  endif(ROCM_FOUND)
endmacro(find_rocm)
86 changes: 86 additions & 0 deletions python/tvm/contrib/hipblas.py
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to hipBLAS libraries."""
import tvm
from tvm import te


def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
    """Create an extern op that computes the matrix product of lhs and rhs with hipBLAS.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand
    rhs : Tensor
        The right matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs
    dtype : str, optional
        Output data type; defaults to the dtype of lhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    n = lhs.shape[1] if transa else lhs.shape[0]
    m = rhs.shape[0] if transb else rhs.shape[1]
    dtype = dtype if dtype is not None else lhs.dtype
    return te.extern(
        (n, m),
        [lhs, rhs],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.hipblas.matmul", ins[0], ins[1], outs[0], transa, transb
        ),
        dtype=dtype,
        name="matmul_hipblas",
    )


def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
    """Create an extern op that computes the batched matrix product of lhs and rhs with hipBLAS.

    Parameters
    ----------
    lhs : Tensor
        The left batched matrix operand
    rhs : Tensor
        The right batched matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs
    dtype : str, optional
        Output data type; defaults to the dtype of lhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    b = lhs.shape[0]
    n = lhs.shape[2] if transa else lhs.shape[1]
    m = rhs.shape[1] if transb else rhs.shape[2]
    dtype = dtype if dtype is not None else lhs.dtype
    return te.extern(
        (b, n, m),
        [lhs, rhs],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.hipblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
        ),
        dtype=dtype,
        name="batch_matmul_hipblas",
    )
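As a usage sketch of the two wrappers above (assumes a TVM build with USE_ROCM and USE_HIPBLAS enabled; shapes and dtypes are illustrative):

```python
import tvm
from tvm import te
from tvm.contrib import hipblas

# Declare an fp16 GEMM that lowers to the tvm.contrib.hipblas.matmul packed call.
A = te.placeholder((1024, 64), name="A", dtype="float16")
B = te.placeholder((64, 512), name="B", dtype="float16")
C = hipblas.matmul(A, B)  # shape (1024, 512), dtype inherited from A

# batch_matmul follows the same pattern with a leading batch axis.
D = te.placeholder((8, 1024, 64), name="D", dtype="float16")
E = te.placeholder((8, 64, 512), name="E", dtype="float16")
F = hipblas.batch_matmul(D, E)  # shape (8, 1024, 512)
```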
180 changes: 180 additions & 0 deletions python/tvm/relax/backend/contrib/hipblas.py
@@ -0,0 +1,180 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Pattern table for hipblas backend"""
import operator
from functools import reduce

import tvm
from tvm.relax import transform
from tvm.relax.transform import PatternCheckContext

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_matmul_pattern
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):  # pylint: disable=unused-argument
    """Check if dtypes in the given workload are supported by hipblas BYOC."""
    if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
        # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
        # return out_dtype != "e5m2_float8"
        return False
    return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
        lhs_dtype == "int8" and rhs_dtype == "int8"
    )


def _check_matmul(context: PatternCheckContext) -> bool:
    if has_leaking_intermediate_variables(context):
        return False
    lhs = context.annotated_expr["lhs"]
    rhs = context.annotated_expr["rhs"]
    matmul_call = context.annotated_expr["root"]

    lhs_dtype = lhs.struct_info.dtype
    rhs_dtype = rhs.struct_info.dtype
    out_dtype = matmul_call.struct_info.dtype
    if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
        return False

    lhs_shape = lhs.struct_info.shape.values
    rhs_shape = rhs.struct_info.shape.values

    if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
        # The reduction axis must be constant
        return False

    if lhs_dtype == "int8" and rhs_dtype == "int8":
        return False
    elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
        return False

    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)

    if "bias" in context.annotated_expr:
        if lhs_dtype == "int8" and rhs_dtype == "int8":
            # A non-default epilogue is not supported for IGEMM
            return False
        bias = context.annotated_expr["bias"]
        bias_shape = bias.struct_info.shape.values
        bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
        if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1:
            # hipBLAS only supports a bias vector
            return False

    # hipblasLt does not seem to support batched GEMM when one of the matrices has
    # a single batch (with batch_stride 0), so for batched GEMM the two batch counts
    # must be equal. If lhs is batched but rhs is not, the regular GEMM can be used
    # by flattening all batch axes into the M axis.
    return (
        isinstance(lhs_batches, tvm.tir.Var)
        or isinstance(rhs_batches, tvm.tir.Var)
        or (int(lhs_batches) == int(rhs_batches))
        or (lhs_batches >= 1 and rhs_batches == 1)
    )
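To make the batch rule at the end of `_check_matmul` concrete, here is how the batch counts work out for two static-shape cases (shapes are illustrative):

```python
import operator
from functools import reduce

# lhs (2, 4, 16, 32) x rhs (2, 4, 32, 8): equal batch counts -> offloadable.
lhs_batches = reduce(operator.mul, (2, 4), 1)  # 8
rhs_batches = reduce(operator.mul, (2, 4), 1)  # 8
assert int(lhs_batches) == int(rhs_batches)

# lhs (8, 16, 32) x rhs (32, 8): rhs has no batch axes, so the lhs batches can
# be flattened into the M axis and a regular GEMM used -> also offloadable.
lhs_batches = reduce(operator.mul, (8,), 1)  # 8
rhs_batches = reduce(operator.mul, (), 1)  # 1
assert lhs_batches >= 1 and rhs_batches == 1
```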


register_patterns(
    [
        (
            "hipblas.matmul",
            *make_matmul_pattern(
                with_bias=False,
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_bias",
            *make_matmul_pattern(
                with_bias=True,
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_bias_relu",
            *make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.relu",
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_bias_gelu",
            *make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.gelu",
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_transposed",
            *make_matmul_pattern(
                with_bias=False,
                transposed_rhs=True,
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_transposed_bias",
            *make_matmul_pattern(
                with_bias=True,
                transposed_rhs=True,
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_transposed_bias_relu",
            *make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.relu",
                transposed_rhs=True,
            ),
            _check_matmul,
        ),
        (
            "hipblas.matmul_transposed_bias_gelu",
            *make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.gelu",
                transposed_rhs=True,
            ),
            _check_matmul,
        ),
    ]
)
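A quick way to inspect what the block above registers (a sketch; importing the module is what triggers registration):

```python
import tvm.relax.backend.contrib.hipblas  # noqa: F401  # registers the patterns
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix

for entry in get_patterns_with_prefix("hipblas"):
    print(entry.name)  # hipblas.matmul, hipblas.matmul_bias, ...
```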


def partition_for_hipblas(mod):
    """Partition the input module into hipblas-supported subgraphs.

    Parameters
    ----------
    mod: tvm.IRModule
        The IRModule to be partitioned.

    Returns
    -------
    mod: tvm.IRModule
        The resulting IRModule, containing partitioned subgraphs to be
        offloaded to the hipblas backend.
    """
    patterns = get_patterns_with_prefix("hipblas")
    return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
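A typical offloading flow built on this helper might look as follows (a sketch; `mod` is assumed to be a Relax IRModule containing fp16 matmuls):

```python
from tvm import relax
from tvm.relax.backend.contrib.hipblas import partition_for_hipblas

mod = partition_for_hipblas(mod)  # fuse supported subgraphs into composite functions
mod = relax.transform.RunCodegen()(mod)  # lower the partitions to hipBLAS runtime calls
```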
3 changes: 3 additions & 0 deletions python/tvm/testing/utils.py
@@ -949,6 +949,9 @@ def _multi_gpu_exists()
    parent_features="rocm",
)

# Mark a test as requiring the hipBLAS library.
requires_hipblas = Feature("hipblas", "hipBLAS", cmake_flag="USE_HIPBLAS", parent_features="rocm")

# Mark a test as requiring the metal runtime
requires_metal = Feature(
"metal",
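In a test file, the new feature is used like the other `requires_*` markers; since it declares `rocm` as a parent feature, it also implies a ROCm-capable build and device. A sketch:

```python
import tvm.testing


@tvm.testing.requires_hipblas
def test_hipblas_offload():
    ...  # skipped unless TVM was built with USE_HIPBLAS and ROCm is usable
```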