
Changed FullyConnected to use new linalg gemm, plus TensorCore if fp16 I/O. (#7505)

* Converted FullyConnected to use new linalg gemm, plus TensorCore if fp16 I/O.

* Simplified linalg_gemm interface to ease integration.

* Correcting code in response to comments.

* Removing Transpose(), leaving trailing req arg with default of kWriteTo.
DickJC123 authored and piiswrong committed Aug 18, 2017
1 parent 56eae58 commit ff21e1f
Showing 5 changed files with 138 additions and 11 deletions.
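
The simplified interface mentioned in the commit message reduces an operator-side GEMM to a single call. A minimal sketch of the new call shape, using the names that appear in the FullyConnected diff below (the stream and request are trailing arguments, with the request defaulting to kWriteTo):

// out = data * wmat^T, overwriting out (the default kWriteTo request)
linalg_gemm(data, wmat, out, false, true, s);
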
34 changes: 27 additions & 7 deletions src/common/cuda_utils.h
@@ -200,7 +200,7 @@ inline DType __device__ CudaMin(DType a, DType b) {
{ \
cublasStatus_t e = (func); \
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
<< "cuBLAS: " << common::cuda::CublasGetErrorString(e); \
<< "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
}

/*!
@@ -213,7 +213,7 @@ inline DType __device__ CudaMin(DType a, DType b) {
{ \
cusolverStatus_t e = (func); \
CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
<< "cuSolver: " << common::cuda::CusolverGetErrorString(e); \
<< "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
}

/*!
@@ -226,7 +226,7 @@ inline DType __device__ CudaMin(DType a, DType b) {
{ \
curandStatus_t e = (func); \
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
<< "cuRAND: " << common::cuda::CurandGetErrorString(e); \
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
}

#if !defined(_MSC_VER)
@@ -304,11 +304,31 @@ inline bool SupportsTensorCore(int device_id) {
* \return whether to allow TensorCore algo (if not specified by the Operator locally).
*/
inline bool GetEnvAllowTensorCore() {
// Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
return dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
dmlc::optional<bool>(default_value)).value();
// Since these statics are in the '.h' file, they will exist and will be set
// separately in each compilation unit. Not ideal, but cleaner than creating a
// cuda_utils.cc solely to have a single instance and initialization.
static bool allow_tensor_core = false;
static bool is_set = false;
if (!is_set) {
// Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
dmlc::optional<bool>(default_value)).value();
is_set = true;
}
return allow_tensor_core;
}

#if CUDA_VERSION >= 9000
// Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns the previous mode.
inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
auto handle_math_mode = CUBLAS_DEFAULT_MATH;
CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
return handle_math_mode;
}
#endif

#endif // MXNET_USE_CUDA

#if MXNET_USE_CUDNN
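
Taken together, GetEnvAllowTensorCore() and SetCublasMathMode() give the save/set/restore pattern that the fp16 GEMM specialization below relies on. A minimal sketch, assuming CUDA >= 9.0 and a valid cublasHandle_t named blas_handle:

auto math_mode = GetEnvAllowTensorCore() ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
auto previous_mode = SetCublasMathMode(blas_handle, math_mode);
// ... issue cuBLAS calls that may now use TensorCore ...
SetCublasMathMode(blas_handle, previous_mode);
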
14 changes: 10 additions & 4 deletions src/operator/fully_connected-inl.h
@@ -33,7 +33,7 @@
#include <utility>
#include "./operator_common.h"
#include "./elemwise_op_common.h"

#include "linalg.h"

namespace mxnet {
namespace op {
@@ -96,7 +96,9 @@ class FullyConnectedOp : public Operator {
Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
out = dot(data, wmat.T());
// Legacy approach shown here for comparison:
// out = dot(data, wmat.T());
linalg_gemm(data, wmat, out, false, true, s);
if (!param_.no_bias) {
Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
out += repmat(bias, data.size(0));
@@ -136,7 +138,9 @@ class FullyConnectedOp : public Operator {
CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
// gradient of weight
Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);
Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
// Legacy approach shown here for comparison:
// Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
// gradient of bias
if (!param_.no_bias) {
Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
@@ -145,7 +149,9 @@ class FullyConnectedOp : public Operator {
// gradient of data
Tensor<xpu, 2, DType> gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
Assign(gdata, req[fullc::kData], dot(grad, wmat));
// Legacy approach shown here for comparison:
// Assign(gdata, req[fullc::kData], dot(grad, wmat));
linalg_gemm(grad, wmat, gdata, false, false, s, req[fullc::kData]);
}

private:
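
A worked shape example for the three calls above (an illustration, not part of the diff): with data of shape (batch, input) and wmat of shape (hidden, input), the forward GEMM computes out = data * wmat^T of shape (batch, hidden); in the backward pass, gwmat = grad^T * data has shape (hidden, input) and gdata = grad * wmat has shape (batch, input), with each req[...] deciding whether the result overwrites or accumulates into its destination.
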
10 changes: 10 additions & 0 deletions src/operator/linalg.h
@@ -26,6 +26,8 @@
#define MXNET_OPERATOR_LINALG_H_

#include <mshadow/tensor.h>
#include <mxnet/op_attr_types.h>

#include "./c_lapack_api.h"
using namespace mshadow;

@@ -62,6 +64,14 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DTyp
const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s = 0);

template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
const Tensor<xpu, 2, DType>& C,
bool tA, bool tB,
Stream<xpu> *s = 0,
mxnet::OpReqType req = mxnet::kWriteTo);
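
Because the trailing stream and request arguments carry defaults, a caller that always overwrites its output can omit the request entirely; a minimal usage sketch (A, B, C are hypothetical 2-D tensors of compatible shapes):

// Equivalent to linalg_gemm(A, B, C, false, true, s, mxnet::kWriteTo).
linalg_gemm(A, B, C, false, true, s);
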

//////////////////////////////// TRSM ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
86 changes: 86 additions & 0 deletions src/operator/linalg_impl.h
@@ -25,8 +25,12 @@
#ifndef MXNET_OPERATOR_LINALG_IMPL_H_
#define MXNET_OPERATOR_LINALG_IMPL_H_

#include <mxnet/op_attr_types.h>

#include <algorithm>

#include "../common/cuda_utils.h"

// Convenience functions.
inline void linalg_check_batch_size(int A, int B, int C) {
CHECK_EQ(A, B) << "Inconsistent batch size between arguments to linear algebra operator";
@@ -108,6 +112,55 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2
LINALG_GPU_GEMM(Sgemm, float)
LINALG_GPU_GEMM(Dgemm, double)

// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
const Tensor<gpu, 2, mshadow::half::half_t>& B,
const Tensor<gpu, 2, mshadow::half::half_t>& C,
mshadow::half::half_t alpha,
mshadow::half::half_t beta,
bool tA, bool tB, Stream<gpu> *s) {
using namespace mxnet;
using mshadow::gpu;
CHECK_NOTNULL(s);
check_gemm(A, B, C, alpha, beta, tA, tB);

#if CUDA_VERSION >= 7050
auto blas_handle = Stream<gpu>::GetBlasHandle(s);
#if CUDA_VERSION >= 9000
auto cublas_math_mode = GetEnvAllowTensorCore() ? CUBLAS_TENSOR_OP_MATH
: CUBLAS_DEFAULT_MATH;
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
#endif

// pseudo-fp16 (fp32 math with fp16 I/O)
float alpha_f = float(alpha); // NOLINT(*)
float beta_f = float(beta); // NOLINT(*)

// As of CUDA 8, cuBLAS adopted the CUDA data type rather than maintaining its own.
#if CUDA_VERSION >= 8000
cudaDataType_t half_datatype = CUDA_R_16F;
#else
cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
#endif
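// Note on the argument order below: mshadow tensors are row-major while cuBLAS is
// column-major, so the operands are passed swapped (B before A), each keeping its own
// transpose flag. In cuBLAS's column-major view this computes C^T = op(B)^T * op(A)^T,
// which is exactly the desired row-major C = op(A) * op(B).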
CUBLAS_CALL(cublasSgemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
&alpha_f,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
&beta_f,
C.dptr_, half_datatype, C.stride_));
#if CUDA_VERSION >= 9000
SetCublasMathMode(blas_handle, previous_math_mode);
#endif
#else
LOG(FATAL) << "FP16 gemm requires CUDA version >= 7.5!";
#endif // CUDA_VERSION >= 7050
}


#define LINALG_GPU_BATCH_GEMM(fname, DType) \
template<> inline \
void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
@@ -246,6 +299,39 @@ LINALG_GPU_BATCH_TRSM(DtrsmBatched, double)

#endif

/*!
* \brief Performs gemm, setting alpha and beta as appropriate for `req`.
*
* \param A the first operand of the gemm
* \param B the second operand of the gemm
* \param C the data to be assigned
* \param tA whether the `A` operand should be transposed first.
* \param tB whether the `B` operand should be transposed first.
* \param s the stream to perform the operation
* \param req the assignment request
*/
template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
const Tensor<xpu, 2, DType>& C,
bool tA, bool tB, Stream<xpu> *s,
mxnet::OpReqType req) {
using namespace mxnet;
switch (req) {
case kNullOp:
break;
case kWriteTo:
case kWriteInplace:
linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
break;
case kAddTo:
linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
break;
default:
LOG(FATAL) << "not reached";
}
}
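
In effect, the wrapper above maps the assignment request onto the GEMM coefficients: kWriteTo and kWriteInplace overwrite C (alpha = 1, beta = 0), kAddTo accumulates into C (alpha = 1, beta = 1), and kNullOp skips the operation entirely.
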

//////////////////////////////// TRMM ////////////////////////////////////////////

// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
5 changes: 5 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -926,6 +926,11 @@ def test_fullyconnected_with_type():
{'ctx': mx.cpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float64}},
{'ctx': mx.cpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float32}}]
check_consistency(sym, ctx_list)
# Sizes are divisible by 8 to test TensorCore use on a Volta GPU.
sym = mx.sym.FullyConnected(num_hidden=8, name='inner')
ctx_list = [{'ctx': mx.gpu(0), 'inner_data': (16, 24), 'type_dict': {'inner_data': np.float16}},
{'ctx': mx.cpu(0), 'inner_data': (16, 24), 'type_dict': {'inner_data': np.float32}}]
check_consistency(sym, ctx_list)


def test_activation_with_type():
