Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Enabling BRGEMM FullyConnected based on shapes (#20533)
Browse files Browse the repository at this point in the history
* Enable brgemm based on input info

* fix sanity

* Review fixes

* Change function name

* Fix typo

* Align variable assignments

* Fix review

* use const reference
bgawrych authored Sep 2, 2021
1 parent 2b9607a commit 59e9b94
Showing 3 changed files with 24 additions and 12 deletions.
6 changes: 3 additions & 3 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
@@ -298,9 +298,9 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Values: Int ```(default=-1)```
- Flag to set num of elements that MKLDNN cache can hold. Default is -1 which means cache size is unbounded. Should only be set if your model has variable input shapes, as cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use MKLDNN multiplied by the number of different input shapes they receive.

* MXNET_MKLDNN_DISABLE_BRGEMM_FC
- Values: 0, 1 ```(default=1)```
- Flag which disables BRGEMM kernels in FullyConnected executed with MKLDNN support - Should only be set to 0 if your model has constant input shapes or FullyConnected is calculated with large tensors. Supported on machines with AVX512-VNNI.
* MXNET_MKLDNN_FORCE_FC_AB_FORMAT
- Values: 0, 1 ```(default=0)```
- If set to true, FullyConnected will use only the AB format for weights. As a result, MXNet will not use the BRGEMM implementation of FC on machines with AVX512-VNNI support, since that implementation requires a special weights format.

* MXNET_ENFORCE_DETERMINISM
- Values: 0(false) or 1(true) ```(default=0)```
19 changes: 15 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
@@ -286,17 +286,28 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) {
// Decides whether the BRGEMM-based FullyConnected kernel should be used
// for a weight matrix of shape `weight_dims` ({out_features, in_features})
// and the given `batch_size`. Thresholds come from measurements performed
// on a CLX8280 machine; see https://github.com/apache/incubator-mxnet/pull/20533
inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) {
  // BRGEMM pays off only for sufficiently large problems ...
  const bool shapes_large_enough =
      weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384;
  // ... whose weight dimensions are multiples of 64.
  const bool dims_divisible_by_64 = weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
  return shapes_large_enough && dims_divisible_by_64;
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
size_t batch_size,
int dtype = -1) {
int ndim = arr.shape().ndim();
mkldnn::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++)
dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
const bool brgemm_disabled = dmlc::GetEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC", true);
if (dims.size() == 2 && brgemm_disabled) {
format = mkldnn::memory::format_tag::ab;
const bool force_fc_ab_format = dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
if (dims.size() == 2) {
if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
format = mkldnn::memory::format_tag::ab;
}
}

return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
11 changes: 6 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
@@ -41,10 +41,11 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
const NDArray& weight,
const NDArray* bias,
const mkldnn::memory::desc& out_md) {
auto data_md = GetMemDesc(data);
auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight, mshadow::kInt8)
: GetFCWeightDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
auto data_md = GetMemDesc(data);
auto weight_md = full_param.mkldnn_param.quantized
? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
: GetFCWeightDesc(weight, data.shape()[0]);
auto propagation =
is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;

@@ -107,7 +108,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray& output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetFCWeightDesc(weight);
auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
@@ -121,7 +122,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
const NDArray& output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetFCWeightDesc(weight);
auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {

0 comments on commit 59e9b94

Please sign in to comment.