Fix batchnorm problem with sparse matrices when fix_gamma=True #11656
Conversation
@marcoabreu Seems like there's no build triggered for this: http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/activity/?branch=PR-11656. Can you take a look?
Link to the issue: #11655
@anirudh2290 That's a duplicate issue of #11654.
LGTM
src/operator/batch_norm_v1.cc
Outdated
@@ -89,6 +89,9 @@ the output. It is often used during inference.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0.

+There's no sparse support for this operator, and will exhibit problematic behavior if used with
Nit: "There's no sparse support for this operator, and will" -> "There's no sparse support for this operator. It will"
Done.
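For readers unfamiliar with `fix_gamma`, the docstring quoted in this thread ("set ``gamma`` to 1 and its gradient to 0") can be illustrated with a minimal, dense-only sketch in plain Python. This is not MXNet's implementation: the function names `batch_norm_1d` and `gamma_grad`, the single-channel list input, and the `eps` default are all assumptions made for illustration.

```python
import math


def _normalize(xs, eps):
    """Return the batch-normalized values of xs (zero mean, unit variance)."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(x - mean) * inv_std for x in xs]


def batch_norm_1d(xs, gamma, beta, fix_gamma=False, eps=1e-3):
    """Hypothetical single-channel batch norm using batch statistics."""
    if fix_gamma:
        gamma = 1.0  # fix_gamma=True: the supplied gamma is ignored and treated as 1
    return [gamma * x_hat + beta for x_hat in _normalize(xs, eps)]


def gamma_grad(xs, out_grads, fix_gamma=False, eps=1e-3):
    """Hypothetical gradient of the loss with respect to gamma."""
    if fix_gamma:
        return 0.0  # fix_gamma=True: gamma receives no gradient
    return sum(g * x_hat for g, x_hat in zip(out_grads, _normalize(xs, eps)))
```

With `fix_gamma=True`, whatever value is passed as `gamma` has no effect on the output, and the returned gradient is always zero.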
@@ -233,7 +233,7 @@ def check_batchnorm_training(stype):
mx.nd.array(beta).tostype(stype)]
mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]

-test = mx.symbol.BatchNorm(data, fix_gamma=True)
+test = mx.symbol.BatchNorm(data, fix_gamma=False)
Could you elaborate why you are modifying a test case?
I thought we had this conversation in Lai's PR yesterday, and I've created the related issue #11647.
The purpose of this MKL test is to check that we fall back correctly for sparse matrices when using MKLDNN, in the cases that are legal for sparse matrices. The logic that excludes the illegal case for sparse matrices (fix_gamma=True), which is shared by both the USE_MKLDNN=1 and USE_MKLDNN=0 situations, is already tested by the test_batchnorm_fallback test.
Ah I see, sorry, I've got a lot of PRs swirling around my head. Let me rephrase to make sure I got it right: this means that the test case with fix_gamma=True is basically invalid until that issue has been resolved, and that's why you are changing it, correct?
If yes, we're good to go. Sorry for the inconvenience.
Yes you got it, that's why we're only testing the necessary legal case here. Please merge this once you feel good to do so, thanks!
I'd like to request a review from another committer to assess the backend changes, if you don't mind.
No problem. @zheng-da @eric-haibin-lin
BTW I think @anirudh2290 already approved the previous #11631
97dec9a to 7613011
In general looks good to me.
src/operator/nn/batch_norm.cc
Outdated
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
}
if ((common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
Change it to !common::ContainsOnlyStorage(kDefault), in case we add more stypes in the future.
Done.
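The reasoning behind this suggestion can be sketched in plain Python. Enumerating the known sparse storage types silently misses any type added later, whereas negating an "only default" check covers it automatically. The helper names and string constants below are hypothetical stand-ins and do not reproduce the signatures of MXNet's C++ helpers.

```python
# Hypothetical stand-ins for storage type tags (not MXNet's real enum values).
DEFAULT, ROW_SPARSE, CSR = "default", "row_sparse", "csr"


def contains_storage_type(stypes, target):
    """True if at least one entry has the given storage type."""
    return any(s == target for s in stypes)


def contains_only_storage(stypes, target):
    """True if every entry has the given storage type."""
    return all(s == target for s in stypes)


def has_sparse_by_enumeration(stypes):
    """Original style: list every sparse type known today."""
    return (contains_storage_type(stypes, ROW_SPARSE)
            or contains_storage_type(stypes, CSR))


def has_non_default(stypes):
    """Suggested style: anything that is not purely default storage."""
    return not contains_only_storage(stypes, DEFAULT)
```

For inputs that use only the storage types listed today, the two predicates agree. They diverge only when an input carries a storage type that the enumeration does not mention, which is the future case the reviewer is guarding against.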
7613011 to 5a75911
5a75911 to a31ee76
Description
As title.
Comments
Resurrection of #11631.