From 48c868a4940c2f1a437bf508cdd1249fe5786e11 Mon Sep 17 00:00:00 2001
From: Fisher
Date: Fri, 12 May 2023 08:40:18 +0000
Subject: [PATCH] cublas gemm support fp64

---
 cinn/runtime/cuda/cublas_util.h    |  56 ++++
 cinn/runtime/cuda/cuda_util.cc     |   4 +
 python/tests/ops/test_matmul_op.py | 441 +++++++++++++----------------
 3 files changed, 253 insertions(+), 248 deletions(-)

diff --git a/cinn/runtime/cuda/cublas_util.h b/cinn/runtime/cuda/cublas_util.h
index 84c185a6fd..24ae8774c2 100644
--- a/cinn/runtime/cuda/cublas_util.h
+++ b/cinn/runtime/cuda/cublas_util.h
@@ -52,6 +52,23 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype,
                        reinterpret_cast<const float *>(&beta),
                        reinterpret_cast<float *>(C),
                        ldc);
+  } else if (dtype == CUDA_R_64F) {
+    const double alpha_fp64 = static_cast<double>(alpha);
+    const double beta_fp64 = static_cast<double>(beta);
+    return cublasDgemm(handle,
+                       transa,
+                       transb,
+                       m,
+                       n,
+                       k,
+                       &alpha_fp64,
+                       reinterpret_cast<const double *>(A),
+                       lda,
+                       reinterpret_cast<const double *>(B),
+                       ldb,
+                       &beta_fp64,
+                       reinterpret_cast<double *>(C),
+                       ldc);
   } else if (dtype == CUDA_R_16F) {
     common::float16 alpha_fp16{alpha};
     common::float16 beta_fp16{beta};
@@ -135,6 +152,27 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype,
                                      ldc,
                                      strideC,
                                      batchCount);
+  } else if (dtype == CUDA_R_64F) {
+    const double alpha_fp64 = static_cast<double>(alpha);
+    const double beta_fp64 = static_cast<double>(beta);
+    return cublasDgemmStridedBatched(handle,
+                                     transa,
+                                     transb,
+                                     m,
+                                     n,
+                                     k,
+                                     &alpha_fp64,
+                                     reinterpret_cast<const double *>(A),
+                                     lda,
+                                     strideA,
+                                     reinterpret_cast<const double *>(B),
+                                     ldb,
+                                     strideB,
+                                     &beta_fp64,
+                                     reinterpret_cast<double *>(C),
+                                     ldc,
+                                     strideC,
+                                     batchCount);
   } else if (dtype == CUDA_R_16F) {
     common::float16 alpha_fp16{alpha};
     common::float16 beta_fp16{beta};
@@ -220,6 +258,24 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype,
                               reinterpret_cast<float **>(C),
                               ldc,
                               batchCount);
+  } else if (dtype == CUDA_R_64F) {
+    const double alpha_fp64 = static_cast<double>(alpha);
+    const double beta_fp64 = static_cast<double>(beta);
+    return cublasDgemmBatched(handle,
+                              transa,
+                              transb,
+                              m,
+                              n,
+                              k,
+                              &alpha_fp64,
+                              reinterpret_cast<const double **>(A),
+                              lda,
+                              reinterpret_cast<const double **>(B),
+                              ldb,
+                              &beta_fp64,
+                              reinterpret_cast<double **>(C),
+                              ldc,
+                              batchCount);
   } else if (dtype == CUDA_R_16F) {
     __half alpha_fp16{alpha};
     __half beta_fp16{beta};
diff --git a/cinn/runtime/cuda/cuda_util.cc b/cinn/runtime/cuda/cuda_util.cc
index 79d4f7b90b..59a473c7be 100644
--- a/cinn/runtime/cuda/cuda_util.cc
+++ b/cinn/runtime/cuda/cuda_util.cc
@@ -160,6 +160,8 @@ void cinn_call_cublas(void *v_args,
     cuda_dtype = CUDA_R_16F;
   } else if (is_float && bytes == sizeof(float)) {
     cuda_dtype = CUDA_R_32F;
+  } else if (is_float && bytes == sizeof(double)) {
+    cuda_dtype = CUDA_R_64F;
   } else if (is_bfloat16) {
     cuda_dtype = CUDA_R_16BF;
   } else {
@@ -311,6 +313,8 @@ void cinn_call_batched_cublas(void *v_args,
     cuda_dtype = CUDA_R_16F;
   } else if (is_float && bytes == sizeof(float)) {
     cuda_dtype = CUDA_R_32F;
+  } else if (is_float && bytes == sizeof(double)) {
+    cuda_dtype = CUDA_R_64F;
   } else if (is_bfloat16) {
     cuda_dtype = CUDA_R_16BF;
   } else {
diff --git a/python/tests/ops/test_matmul_op.py b/python/tests/ops/test_matmul_op.py
index efe9eb4f87..d3637a1481 100755
--- a/python/tests/ops/test_matmul_op.py
+++ b/python/tests/ops/test_matmul_op.py
@@ -17,6 +17,7 @@
 import unittest
 import numpy as np
 from op_test import OpTest, OpTestTool
+from op_test_helper import TestCaseHelper
 import paddle
 import paddle.nn.functional as F
 import cinn
@@ -28,281 +29,225 @@
                     "x86 test will be skipped due to timeout.")
 class TestMatmulOp(OpTest):
     def setUp(self):
-        self.init_case()
+        # print(f'{self.__class__.__name__}: {self.case}')
+        self.prepare_inputs()

-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([4, 16]).astype("float32"),
-            "y": np.random.random([16, 32]).astype("float32")
-        }
-        self.transpose_x = False
-        self.transpose_y = False
+    def prepare_inputs(self):
+        self.x_np = self.random(self.case["x_shape"], self.case["dtype"])
+        self.y_np = self.random(self.case["y_shape"], self.case["dtype"])

     def paddle_func(self, x, y):
         return paddle.matmul(
-            x, y, transpose_x=self.transpose_x, transpose_y=self.transpose_y)
+            x,
+            y,
+            transpose_x=self.case["transx"],
+            transpose_y=self.case["transy"])

     def build_paddle_program(self, target):
-        x = paddle.to_tensor(self.inputs["x"], stop_gradient=True)
-        y = paddle.to_tensor(self.inputs["y"], stop_gradient=True)
-
+        x = paddle.to_tensor(self.x_np, stop_gradient=True)
+        y = paddle.to_tensor(self.y_np, stop_gradient=True)
         out = self.paddle_func(x, y)
-
         self.paddle_outputs = [out]

     def cinn_func(self, builder, x, y):
         return builder.matmul(
-            x, y, transpose_x=self.transpose_x, transpose_y=self.transpose_y)
+            x,
+            y,
+            transpose_x=self.case["transx"],
+            transpose_y=self.case["transy"])

     def build_cinn_program(self, target):
         builder = NetBuilder("matmul")
-        x = builder.create_input(Float(32), self.inputs["x"].shape, "x")
-        y = builder.create_input(Float(32), self.inputs["y"].shape, "y")
+        x = builder.create_input(
+            self.nptype2cinntype(self.case["dtype"]), self.case["x_shape"],
+            "x")
+        y = builder.create_input(
+            self.nptype2cinntype(self.case["dtype"]), self.case["y_shape"],
+            "y")
         out = self.cinn_func(builder, x, y)
-
         prog = builder.build()
         res = self.get_cinn_output(
-            prog,
-            target, [x, y], [self.inputs["x"], self.inputs["y"]], [out],
-            passes=list())
-
-        self.cinn_outputs = [res[0]]
+            prog, target, [x, y], [self.x_np, self.y_np], [out], passes=list())
+        self.cinn_outputs = res

     def test_check_results(self):
-        self.check_outputs_and_grads()
-
-
-class TestMatmulCase1(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([2, 16]).astype("float32"),
-            "y": np.random.random([16, 2]).astype("float32")
-        }
-        self.transpose_x = False
-        self.transpose_y = False
-
-
-class TestMatmulCase2(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([5, 4, 16]).astype("float32"),
-            "y": np.random.random([5, 16, 32]).astype("float32")
-        }
-        self.transpose_x = False
-        self.transpose_y = False
-
-
-class TestMatmulCase3(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([5, 16, 4]).astype("float32"),
-            "y": np.random.random([5, 16, 32]).astype("float32")
-        }
-        self.transpose_x = True
-        self.transpose_y = False
-
-
-class TestMatmulCase4(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([16, 4]).astype("float32"),
-            "y": np.random.random([16, 32]).astype("float32")
-        }
-        self.transpose_x = True
-        self.transpose_y = False
-
-
-class TestMatmulCase5(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([4, 16]).astype("float32"),
-            "y": np.random.random([32, 16]).astype("float32")
-        }
-        self.transpose_x = False
-        self.transpose_y = True
-
-
-class TestMatmulCase6(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([16, 4]).astype("float32"),
-            "y": np.random.random([32, 16]).astype("float32")
-        }
-        self.transpose_x = True
-        self.transpose_y = True
-
-
-class TestMatmulCase7(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([8, 16, 4]).astype("float32"),
4]).astype("float32"), - "y": np.random.random([1, 4, 16]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase8(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 16, 4]).astype("float32"), - "y": np.random.random([1, 4, 16]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase9(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 16, 4]).astype("float32"), - "y": np.random.random([4, 16]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase10(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 16, 4]).astype("float32"), - "y": np.random.random([4, 16]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase11(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 4, 16]).astype("float32"), - "y": np.random.random([4, 16]).astype("float32") - } - self.transpose_x = True - self.transpose_y = False - - -class TestMatmulCase12(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 16, 4]).astype("float32"), - "y": np.random.random([16, 1]).astype("float32") - } - self.transpose_x = True - self.transpose_y = False - - -class TestMatmulCase13(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([4, 16]).astype("float32"), - "y": np.random.random([16, 1]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase14(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([10, 1, 128, 64]).astype("float32"), - "y": np.random.random([10, 12, 64, 128]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase15(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([10, 12, 128, 64]).astype("float32"), - "y": np.random.random([10, 12, 128, 64]).astype("float32") - } - self.transpose_x = False - self.transpose_y = True - - -class TestMatmulCase16(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([10, 12, 64, 128]).astype("float32"), - "y": np.random.random([10, 12, 128, 64]).astype("float32") - } - self.transpose_x = True - self.transpose_y = True - - -class TestMatmulCase17(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([128]).astype("float32"), - "y": np.random.random([10, 12, 128, 64]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase18(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([128]).astype("float32"), - "y": np.random.random([10, 12, 128, 64]).astype("float32") - } - self.transpose_x = True - self.transpose_y = False - - -class TestMatmulCase19(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([128]).astype("float32"), - "y": np.random.random([128]).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - - -class TestMatmulCase20(TestMatmulOp): - def init_case(self): - self.inputs = { - "x": np.random.random([128]).astype("float32"), - "y": np.random.random([128]).astype("float32") - } - self.transpose_x = True - self.transpose_y = True + max_relative_error = self.case[ + "max_relative_error"] if "max_relative_error" in self.case else 1e-5 + 
+        self.check_outputs_and_grads(max_relative_error=max_relative_error)
+
+
+class TestMatmulOpShapeDtype(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestMatmulOpCase"
+        self.cls = TestMatmulOp
+        self.inputs = [
+            {
+                "x_shape": [1024],
+                "y_shape": [1024],
+            },
+            {
+                "x_shape": [4, 4],
+                "y_shape": [4, 4],
+            },
+            {
+                "x_shape": [4, 16],
+                "y_shape": [16, 32],
+            },
+            {
+                "x_shape": [5, 4, 16],
+                "y_shape": [5, 16, 32],
+            },
+            {
+                # Matrix mul row vector
+                "x_shape": [4, 16],
+                "y_shape": [16],
+            },
+            {
+                # Matrix mul col vector
+                "x_shape": [4, 16],
+                "y_shape": [16, 1],
+            },
+            {
+                "x_shape": [8, 16, 4],
+                "y_shape": [1, 4, 16],
+            },
+            {
+                "x_shape": [1, 1, 1, 1],
+                "y_shape": [1, 1, 1, 1],
+            },
+        ]
+        self.dtypes = [
+            # {
+            #     "dtype": "bfloat16",
+            # },
+            {
+                "dtype": "float16",
+                "max_relative_error": 1e-3
+            },
+            {
+                "dtype": "float32",
+            },
+            {
+                "dtype": "float64",
+            },
+        ]
+        self.attrs = [
+            {
+                "transx": False,
+                "transy": False,
+            },
+        ]
+
+
+class TestMatmulOpTrans(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestMatmulOpCase"
+        self.cls = TestMatmulOp
+        self.inputs = [
+            {
+                "x_shape": [16, 4],
+                "y_shape": [16, 32],
+                "transx": True,
+                "transy": False,
+            },
+            {
+                "x_shape": [5, 16, 4],
+                "y_shape": [5, 16, 32],
+                "transx": True,
+                "transy": False,
+            },
+            {
+                "x_shape": [8, 4, 16],
+                "y_shape": [4, 16],
+                "transx": True,
+                "transy": False,
+            },
+            {
+                "x_shape": [4, 16],
+                "y_shape": [32, 16],
+                "transx": False,
+                "transy": True,
+            },
+            {
+                "x_shape": [10, 12, 128, 64],
+                "y_shape": [10, 12, 128, 64],
+                "transx": False,
+                "transy": True,
+            },
+            {
+                "x_shape": [16, 4],
+                "y_shape": [32, 16],
+                "transx": True,
+                "transy": True,
+            },
+            {
+                "x_shape": [10, 12, 64, 128],
+                "y_shape": [10, 12, 128, 64],
+                "transx": True,
+                "transy": True,
+            },
+            {
+                "x_shape": [128],
+                "y_shape": [10, 12, 128, 64],
+                "transx": True,
+                "transy": False,
+            },
+        ]
+        self.dtypes = [
+            {
+                "dtype": "float32",
+            },
+        ]
+        self.attrs = []


 class TestMatmulTransposePass(TestMatmulOp):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([10, 1, 128, 64]).astype("float32"),
-            "y": np.random.random([10, 12, 64, 128]).astype("float32")
-        }
-        self.transpose_x = False
-        self.transpose_y = False
-        self.perm = [0, 1, 3, 2]
-
     def paddle_func(self, x, y):
         out = paddle.matmul(
-            x, y, transpose_x=self.transpose_x, transpose_y=self.transpose_y)
-        return paddle.transpose(out, self.perm)
+            x,
+            y,
+            transpose_x=self.case["transx"],
+            transpose_y=self.case["transy"])
+        return paddle.transpose(out, self.case["perm"])

     def cinn_func(self, builder, x, y):
         out = builder.matmul(
-            x, y, transpose_x=self.transpose_x, transpose_y=self.transpose_y)
-        return builder.transpose(out, self.perm)
-
-
-class TestMatmulTransposePassCase1(TestMatmulTransposePass):
-    def init_case(self):
-        self.inputs = {
-            "x": np.random.random([32, 64]).astype("float32"),
-            "y": np.random.random([64, 128]).astype("float32")
-        }
-        self.transpose_x = False
-        self.transpose_y = False
-        self.perm = [1, 0]
+            x,
+            y,
+            transpose_x=self.case["transx"],
+            transpose_y=self.case["transy"])
+        return builder.transpose(out, self.case["perm"])
+
+
+class TestMatmulTransposePassAll(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestMatmulTransposePassCase"
+        self.cls = TestMatmulTransposePass
+        self.inputs = [
+            {
+                "x_shape": [32, 64],
+                "y_shape": [64, 128],
+                "perm": [1, 0],
+                "transx": False,
+                "transy": False,
+            },
+            {
+                "x_shape": [10, 1, 128, 64],
+                "y_shape": [10, 12, 64, 128],
+                "perm": [0, 1, 3, 2],
"transx": False, + "transy": False, + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestMatmulOpShapeDtype().run() + TestMatmulOpTrans().run() + TestMatmulTransposePassAll().run()