Conversation
Please take a look at #17152 and cache the primitive for the backward pass. Thanks!
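A minimal sketch of the caching pattern this asks for, reusing the MKLDNNSoftmaxBwd, MKLDNNSoftmaxSignature, and GetSoftmaxBwd names quoted later in this conversation; the constructor arguments and the use of OpHash are assumptions, not the final code, and the usual MXNet headers are assumed to be included:

```cpp
// Hedged sketch: cache the backward primitive per thread, keyed by a signature
// built from the op parameters and the input/output NDArrays.
static MKLDNNSoftmaxBwd &GetSoftmaxBwd(const SoftmaxParam &param, const int axis,
                                       const std::vector<NDArray> &data,
                                       const std::vector<NDArray> &output) {
  static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
                                         MKLDNNSoftmaxBwd, OpHash> bwds;
  MKLDNNSoftmaxSignature key(param);
  key.AddSign(axis);
  key.AddSign(data);
  key.AddSign(output);

  auto it = bwds.find(key);
  if (it == bwds.end()) {
    // Build the primitive descriptor and primitive only on a cache miss.
    MKLDNNSoftmaxBwd bwd(data, axis);  // assumed constructor
    it = bwds.insert(std::make_pair(key, bwd)).first;
  }
  return it->second;
}
```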
    { MKLDNN_ARG_DIFF_SRC, *out_mem.second },
  };

  stream->RegisterPrimArgs(bwd_pd, args);
pd?
Will change it to cache the primitive pd.
I mean here you need to pass a primitive, not a primitive descriptor. Please check the definition of RegisterPrimArgs.
Yes, you're right.
Although it is a pd, it calls the implicit constructor
inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
to get a primitive.
I will make it clearer by adding an explicit constructor call before it is used.
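For illustration, a small sketch of what that explicit construction could look like; bwd_pd, args, and stream are the names from the snippet above, and this is not the final code:

```cpp
// Construct the primitive from its primitive descriptor explicitly,
// then hand the primitive (not the pd) to RegisterPrimArgs.
mkldnn::softmax_backward bwd_prim(bwd_pd);
stream->RegisterPrimArgs(bwd_prim, args);
```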
src/operator/nn/softmax.cc
Outdated
                    const std::vector<NDArray>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<NDArray>& outputs) {
  // It seems MKLDNN softmax doesn't support training.
Can you elaborate?
Will remove this outdated comment.
src/operator/nn/softmax.cc
Outdated
inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
                                      const int dev_mask,
                                      DispatchMode* dispatch_mode,
                                      std::vector<int> *in_attrs,
                                      std::vector<int> *out_attrs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U);
Why remove this?
This check makes the backward storage-type check fail.
Will restore this check and add a separate function for backward, sketched below.
  std::shared_ptr<mkldnn::softmax_backward> bwd_;
};

typedef ParamOpSign<SoftmaxParam> MKLDNNSoftmaxSignature;
Same as L99?
Exactly the same; the extra one was added during a rebase. Will remove this line. Thanks.
void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType>& req,
Suggested change:
- const std::vector<OpReqType>& req,
+ const std::vector<OpReqType> &req,
done
  auto data_mem = in_data[1].GetMKLDNNData();
  auto bwd = GetSoftmaxBwd(param, axis, in_data, out_data);

  auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
- Please check whether you want to support req=kAddTo;
- The softmax backward primitive should support in-place calculation, so there is no need to create an additional buffer.
You are right; CreateMKLDNNMem is used to support kAddTo.
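For context, the usual pattern with these helpers (as used by other MKLDNN operators) is sketched below: when req[0] is kAddTo, CreateMKLDNNMem hands back a temporary buffer and CommitOutput accumulates it into the destination. Names follow the snippet above; the primitive registration/execution step is elided:

```cpp
auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
// ... register and execute the backward primitive, writing to *out_mem.second ...
CommitOutput(out_data[0], out_mem);  // no-op for kWriteTo/kWriteInplace, sums for kAddTo
stream->Submit();
```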
Does the original softmax backward support kAddTo?
Yes, the original softmax backward supports kAddTo: SoftmaxGradCompute --> SoftmaxGrad --> KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
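For reference, the req dispatch inside that kernel boils down to roughly the following (paraphrasing the KERNEL_ASSIGN macro in mxnet_op.h; the index variable i stands in for base + j*sa), which is what makes kAddTo work in the original path:

```cpp
switch (Req) {
  case kNullOp:                                    break;  // skip the write
  case kWriteTo:
  case kWriteInplace: igrad[i]  = final_result;    break;  // overwrite
  case kAddTo:        igrad[i] += final_result;    break;  // accumulate (kAddTo)
}
```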
@TaoLv are all concerns resolved now?
* add mkldnn softmax backward
* add primitive cache for softmax bwd
* fix preci failed test
* rm duplicate line
Description
Add MKLDNN softmax backward implementation.
Unit tests pass.
Should fix #13365
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
@PatricZhao @TaoLv