From ff21e1fd41f118dbbaf55d8f02a9669842ef565f Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Thu, 17 Aug 2017 21:16:51 -0700
Subject: [PATCH] Changed FullyConnected to use new linalg gemm, plus TensorCore if fp16 I/O. (#7505)

* Converted FullyConnected to use new linalg gemm, plus TensorCore if fp16 I/O.

* Simplified linalg_gemm interface to ease integration.

* Correcting code in response to comments.

* Removing Transpose(), leaving trailing req arg with default of kWriteTo.
---
 src/common/cuda_utils.h               | 34 ++++++++---
 src/operator/fully_connected-inl.h    | 14 +++--
 src/operator/linalg.h                 | 10 ++++
 src/operator/linalg_impl.h            | 86 +++++++++++++++++++++++++++
 tests/python/gpu/test_operator_gpu.py |  5 ++
 5 files changed, 138 insertions(+), 11 deletions(-)

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 483390fc9bea..0213c73177b3 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -200,7 +200,7 @@ inline DType __device__ CudaMin(DType a, DType b) {
   {                                                          \
     cublasStatus_t e = (func);                               \
     CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                       \
-        << "cuBLAS: " << common::cuda::CublasGetErrorString(e); \
+        << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
   }
 
 /*!
@@ -213,7 +213,7 @@ inline DType __device__ CudaMin(DType a, DType b) {
   {                                                          \
     cusolverStatus_t e = (func);                             \
     CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                     \
-        << "cuSolver: " << common::cuda::CusolverGetErrorString(e); \
+        << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
   }
 
 /*!
@@ -226,7 +226,7 @@ inline DType __device__ CudaMin(DType a, DType b) {
   {                                                          \
     curandStatus_t e = (func);                               \
     CHECK_EQ(e, CURAND_STATUS_SUCCESS)                       \
-        << "cuRAND: " << common::cuda::CurandGetErrorString(e); \
+        << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
   }
 
 #if !defined(_MSC_VER)
@@ -304,11 +304,31 @@ inline bool SupportsTensorCore(int device_id) {
  * \return whether to allow TensorCore algo (if not specified by the Operator locally).
  */
 inline bool GetEnvAllowTensorCore() {
-  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
-  bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
-  return dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
-                      dmlc::optional<bool>(default_value)).value();
+  // Since these statics are in the '.h' file, they will exist and will be set
+  // separately in each compilation unit.  Not ideal, but cleaner than creating a
+  // cuda_utils.cc solely to have a single instance and initialization.
+  static bool allow_tensor_core = false;
+  static bool is_set = false;
+  if (!is_set) {
+    // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
+    bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
+    allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
+                                     dmlc::optional<bool>(default_value)).value();
+    is_set = true;
+  }
+  return allow_tensor_core;
+}
+
+#if CUDA_VERSION >= 9000
+// Sets the cuBLAS math mode that determines the 'allow TensorCore' policy.  Returns previous.
+inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
+  auto handle_math_mode = CUBLAS_DEFAULT_MATH;
+  CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
+  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
+  return handle_math_mode;
 }
+#endif
+
 #endif  // MXNET_USE_CUDA
 
 #if MXNET_USE_CUDNN
diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h
index e2fab9f1f7dd..cf13655d9c97 100644
--- a/src/operator/fully_connected-inl.h
+++ b/src/operator/fully_connected-inl.h
@@ -33,7 +33,7 @@
 #include
 #include "./operator_common.h"
 #include "./elemwise_op_common.h"
-
+#include "linalg.h"
 
 namespace mxnet {
 namespace op {
@@ -96,7 +96,9 @@ class FullyConnectedOp : public Operator {
     Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
     Tensor<xpu, 2, DType> out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
         Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
-    out = dot(data, wmat.T());
+    // Legacy approach shown here for comparison:
+    //   out = dot(data, wmat.T());
+    linalg_gemm(data, wmat, out, false, true, s);
     if (!param_.no_bias) {
       Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
       out += repmat(bias, data.size(0));
@@ -136,7 +138,9 @@ class FullyConnectedOp : public Operator {
     CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
     // gradient of weight
     Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);
-    Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
+    // Legacy approach shown here for comparison:
+    //   Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
+    linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
     // gradient of bias
     if (!param_.no_bias) {
       Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
@@ -145,7 +149,9 @@
     // gradient of data
     Tensor<xpu, 2, DType> gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
         Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
-    Assign(gdata, req[fullc::kData], dot(grad, wmat));
+    // Legacy approach shown here for comparison:
+    //   Assign(gdata, req[fullc::kData], dot(grad, wmat));
+    linalg_gemm(grad, wmat, gdata, false, false, s, req[fullc::kData]);
   }
 
  private:
diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index 9284a5825d2c..76acf7b98f41 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -26,6 +26,8 @@
 #define MXNET_OPERATOR_LINALG_H_
 
 #include
+#include
+
 #include "./c_lapack_api.h"
 
 using namespace mshadow;
@@ -62,6 +64,14 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
                        const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
                        bool tA, bool tB, Stream<xpu> *s = 0);
 
+template<typename xpu, typename DType>
+inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
+                        const Tensor<xpu, 2, DType>& B,
+                        const Tensor<xpu, 2, DType>& C,
+                        bool tA, bool tB,
+                        Stream<xpu> *s = 0,
+                        mxnet::OpReqType req = mxnet::kWriteTo);
+
 //////////////////////////////// TRSM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index affa7941640b..1e3b0e66e641 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -25,8 +25,12 @@
 #ifndef MXNET_OPERATOR_LINALG_IMPL_H_
 #define MXNET_OPERATOR_LINALG_IMPL_H_
 
+#include
+
 #include
 
+#include "../common/cuda_utils.h"
+
 // Convenience functions.
 inline void linalg_check_batch_size(int A, int B, int C) {
   CHECK_EQ(A, B) << "Inconsistent batch size between arguments to linear algebra operator";
@@ -108,6 +112,55 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B,
 LINALG_GPU_GEMM(Sgemm, float)
 LINALG_GPU_GEMM(Dgemm, double)
 
+// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
+template<> inline
+void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
+                                             const Tensor<gpu, 2, mshadow::half::half_t>& B,
+                                             const Tensor<gpu, 2, mshadow::half::half_t>& C,
+                                             mshadow::half::half_t alpha,
+                                             mshadow::half::half_t beta,
+                                             bool tA, bool tB, Stream<gpu> *s) {
+  using namespace mxnet;
+  using mshadow::gpu;
+  CHECK_NOTNULL(s);
+  check_gemm(A, B, C, alpha, beta, tA, tB);
+
+#if CUDA_VERSION >= 7050
+  auto blas_handle = Stream<gpu>::GetBlasHandle(s);
+#if CUDA_VERSION >= 9000
+  auto cublas_math_mode = GetEnvAllowTensorCore() ? CUBLAS_TENSOR_OP_MATH
+                                                  : CUBLAS_DEFAULT_MATH;
+  auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
+#endif
+
+  // pseudo-fp16 (fp32 math with fp16 I/O)
+  float alpha_f = float(alpha);  // NOLINT(*)
+  float beta_f = float(beta);  // NOLINT(*)
+
+  // As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
+#if CUDA_VERSION >= 8000
+  cudaDataType_t half_datatype = CUDA_R_16F;
+#else
+  cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
+#endif
+  CUBLAS_CALL(cublasSgemmEx(blas_handle,
+                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
+                            C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
+                            &alpha_f,
+                            B.dptr_, half_datatype, B.stride_,
+                            A.dptr_, half_datatype, A.stride_,
+                            &beta_f,
+                            C.dptr_, half_datatype, C.stride_));
+#if CUDA_VERSION >= 9000
+  SetCublasMathMode(blas_handle, previous_math_mode);
+#endif
+#else
+  LOG(FATAL) << "FP16 gemm requires CUDA version >= 7.5!";
+#endif  // CUDA_VERSION >= 7050
+}
+
+
 #define LINALG_GPU_BATCH_GEMM(fname, DType) \
 template<> inline \
 void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
@@ -246,6 +299,39 @@ LINALG_GPU_BATCH_TRSM(DtrsmBatched, double)
 
 #endif
 
+/*!
+ * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
+ *
+ * \param A the first operand of the gemm
+ * \param B the second operand of the gemm
+ * \param C the data to be assigned
+ * \tparam tA whether the `A` operand should be transposed first.
+ * \tparam tB whether the `B` operand should be transposed first.
+ * \tparam s the stream to perform the operation
+ * \param req the assignment request
+ */
+template<typename xpu, typename DType>
+inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
+                        const Tensor<xpu, 2, DType>& B,
+                        const Tensor<xpu, 2, DType>& C,
+                        bool tA, bool tB, Stream<xpu> *s,
+                        mxnet::OpReqType req) {
+  using namespace mxnet;
+  switch (req) {
+    case kNullOp:
+      break;
+    case kWriteTo:
+    case kWriteInplace:
+      linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
+      break;
+    case kAddTo:
+      linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
+      break;
+    default:
+      LOG(FATAL) << "not reached";
+  }
+}
+
 //////////////////////////////// TRMM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 866f6ad8abc0..81492fe6bbdb 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -926,6 +926,11 @@ def test_fullyconnected_with_type():
                 {'ctx': mx.cpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float64}},
                 {'ctx': mx.cpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float32}}]
     check_consistency(sym, ctx_list)
+    # Sizes are divisible by 8 to test TensorCore on Volta GPU.
+    sym = mx.sym.FullyConnected(num_hidden=8, name='inner')
+    ctx_list = [{'ctx': mx.gpu(0), 'inner_data': (16, 24), 'type_dict': {'inner_data': np.float16}},
+                {'ctx': mx.cpu(0), 'inner_data': (16, 24), 'type_dict': {'inner_data': np.float32}}]
+    check_consistency(sym, ctx_list)
 
 
 def test_activation_with_type():
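
For reference, a minimal sketch (not part of the patch) of how operator code might call the linalg_gemm convenience wrapper introduced above. The helper name FCForwardGemm and its tensor arguments are hypothetical; only the linalg_gemm call itself comes from this change, and it assumes the mshadow/MXNet types used throughout the patch.

    #include "linalg.h"  // declares the linalg_gemm(A, B, C, tA, tB, s, req) wrapper added here

    // Hypothetical helper: computes out = data * weight^T for a FullyConnected-style layer.
    // Shapes follow the FullyConnected convention: data is (batch, in_dim),
    // weight is (out_dim, in_dim), out is (batch, out_dim).
    template<typename xpu, typename DType>
    void FCForwardGemm(const mshadow::Tensor<xpu, 2, DType>& data,
                       const mshadow::Tensor<xpu, 2, DType>& weight,
                       const mshadow::Tensor<xpu, 2, DType>& out,
                       mshadow::Stream<xpu>* s,
                       mxnet::OpReqType req = mxnet::kWriteTo) {
      // tA = false, tB = true selects data * weight^T; `req` picks alpha/beta
      // (kWriteTo/kWriteInplace -> beta = 0, kAddTo -> beta = 1, kNullOp -> no-op).
      // For DType = mshadow::half::half_t on gpu, the specialization above routes through
      // cublasSgemmEx and, with CUDA >= 9 and MXNET_CUDA_ALLOW_TENSOR_CORE enabled,
      // may use TensorCore math.
      linalg_gemm(data, weight, out, false, true, s, req);
    }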