[CUTLASS] Initial support for conv2d wgrad #10177
Conversation
@@ -252,6 +252,8 @@ def select_op(
     lambda align: all([dim % align == 0 for dim in [IC, OC]]),
     use_3xtf32,
     profile_all_alignments,
+    # Use fp32 accumulation for wgrad to align with cuDNN
+    accumlator_dtype="float32" if conv_kind == ConvKind.Wgrad else out_dtype,
Here, we force the accum dtype to be fp32 if wgrad.
This makes a lot of sense. I'm even wondering whether we should force the accum dtype to be fp32 for all ConvKinds instead of just wgrad.
Always using fp32 accum for the other conv2d kinds would probably bring a perf regression (cuDNN also allows fp16 accumulation for fprop and dgrad). Ideally we should add `accumulation_dtype` to `Conv2DAttrs` to guide that decision. I thought about doing that, but I realized that I would have to change a lot of TOPI to take it into account.

Also, we need to discuss what the interface for `ToMixedPrecision` should be if we want to allow changing the accumulation dtype; right now we cannot flexibly change even the output dtype. @AndrewZhaoLuo
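Purely as an illustration of that idea (my sketch, not part of this PR): the `accumulation_dtype` argument below is hypothetical and does not exist in Relay today; adding it is the `Conv2DAttrs` + TOPI change being discussed.

```python
# Hypothetical sketch only: conv2d currently has no accumulation_dtype attribute.
# The commented-out argument shows what the proposed Conv2DAttrs knob could look
# like; out_dtype is the only dtype knob that exists today.
from tvm import relay

data = relay.var("data", shape=(8, 64, 56, 56), dtype="float16")
weight = relay.var("weight", shape=(128, 64, 3, 3), dtype="float16")
conv = relay.nn.conv2d(
    data, weight,
    channels=128, kernel_size=(3, 3), padding=(1, 1),
    out_dtype="float16",             # existing knob: output dtype
    # accumulation_dtype="float32",  # proposed knob: accumulation dtype
)
```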
I see. Makes sense.
We discussed the accum dtype before, when @AndrewZhaoLuo was working on the ToMixedPrecision pass, but just like you pointed out, it would involve lots of TOPI changes.
Hmm yeah, so if I'm understanding correctly, for conv2d_winograd we want to accumulate to fp32 but if it's not winograd we are ok with accumulating to fp16.

`ToMixedPrecision` can configure accumulation and output dtypes for any call node, but only using information from examining that node. I'm not sure implementation details like whether it's winograd can be transmitted here.

I will say that at the Relay level all we care about is type checking IMO, so just get the output_dtype correct. For example, accumulate all you like in fp32, but internally just make sure the output fits the expected type written in the interface. Perhaps the extraneous cast here is bad, but maybe we can repair it further down at the TOPI/TIR level.
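For concreteness, here is a minimal sketch (my illustration, not code from this PR) of the shape of the per-op hook the pass uses today, which returns a conversion category plus accumulation and output dtypes. The import path, registration API, and the op name used here are assumptions and may differ across TVM versions.

```python
# Sketch of a ToMixedPrecision per-op conversion rule. Exact names/paths follow
# tvm/relay/transform/mixed_precision.py at the time of writing and should be
# treated as assumptions.
from tvm.relay.op import register_mixed_precision_conversion

MIXED_PRECISION_ALWAYS = 0  # the pass's "always convert this op" category

@register_mixed_precision_conversion("nn.conv2d_backward_weight")
def wgrad_mixed_precision_rule(call_node, mixed_precision_type):
    # Return [conversion category, accumulation dtype, output dtype].
    # Only this one call node is visible here, so the decision must be made
    # from the op name/attrs alone -- e.g. always accumulate wgrad in fp32
    # while still producing the requested (fp16) output dtype.
    return [MIXED_PRECISION_ALWAYS, "float32", mixed_precision_type]
```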
Yeah, I agree that we can let `out_dtype` in conv2d essentially act like the accumulation data type, and add an explicit cast op if `accum_dtype != out_dtype`. That's fine as far as the `ToMixedPrecision` pass goes, but for the CUTLASS BYOC we would need to additionally pattern match against the added `cast(fp32 -> fp16)` op to know that this conv2d is fp32 accum -> fp16 out. And for cuDNN, which is not implemented as BYOC, this doesn't work because all it sees is an fp32 accum -> fp32 out conv2d, and cuDNN wgrad doesn't support that dtype combination.
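To make the BYOC side concrete, a minimal sketch (my own, with assumed helper/attribute choices, not this PR's actual pattern table) of matching that conv2d-plus-cast sequence with the Relay dataflow pattern language:

```python
# Sketch: match "conv2d accumulating in fp32, followed by an explicit cast to
# fp16", which is how a BYOC pattern would have to recognize
# fp32 accum -> fp16 out under the cast-insertion scheme discussed above.
from tvm.relay.dataflow_pattern import is_op, wildcard

def conv2d_fp32_accum_fp16_out_pattern():
    data, weight = wildcard(), wildcard()
    # out_dtype="float32" here effectively plays the role of the accumulation dtype
    conv = is_op("nn.conv2d")(data, weight).has_attr({"out_dtype": "float32"})
    # the extra cast op encodes "the real output is fp16"
    return is_op("cast")(conv).has_attr({"dtype": "float16"})
```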
> Hmm yeah, so if I'm understanding correctly, for conv2d_winograd we want to accumulate to fp32 but if it's not winograd we are ok with accumulating to fp16.

@AndrewZhaoLuo Here wgrad means "conv2d gradient with respect to weight", not winograd :)
thanks @comaniac
* [CUTLASS] Add wgrad support (without split-k)
* run black
* wgrad tests now work under pytest
* dw conv2d properly supported for wgrad
* all tests work
* fixed for sm75
* cpplint
* fix conv2d grad test
Add support for offloading the `conv2d_backward_weight` op to CUTLASS. Note that, since in wgrad the K dimension of the implicit GEMM is very large (`N * P * Q`), split-k is required for reasonable performance. I've already implemented split-k support as well, but I'm sending basic wgrad support first. This includes layout conversion for `conv2d_backward_weight`, pattern matching, codegen, and test cases.

The key diff in this PR is the introduction of the accumulation data type, which is now separate from the output dtype. Since wgrad requires fp32 accumulation in practice, even if the output is fp16 we always accumulate in fp32. Note that cuDNN also supports only fp32 accum + fp16 output for wgrad. For GEMM and the other conv2d ops, the accumulation type is set to the output dtype for now.
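As a rough usage sketch (my illustration; the shapes and exact wrapper arguments are assumptions, not taken from this PR), this is the kind of graph being offloaded, and why the GEMM K dimension blows up:

```python
# Sketch: build a Relay graph containing nn.conv2d_backward_weight and hand it
# to the CUTLASS partitioner. Shapes and arguments are illustrative only.
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

N, C, H, W, O = 8, 64, 56, 56, 128
dy = relay.var("dy", shape=(N, O, H, W), dtype="float16")  # gradient of the conv output
x = relay.var("x", shape=(N, C, H, W), dtype="float16")    # forward input activation

# Gradient of conv2d w.r.t. the weight. The implicit GEMM reduces over
# K = N * P * Q = 8 * 56 * 56 = 25088 here, which is why split-k matters for perf.
wgrad = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3), out_dtype="float16"
)

mod = tvm.IRModule.from_expr(relay.Function([dy, x], wgrad))
mod = partition_for_cutlass(mod)  # offloads the matched wgrad pattern to CUTLASS
```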
Also fixes the bug of `conv2d_backward_weight` not supporting depthwise conv2d, as reported in fd5915a#commitcomment-64547258.

@comaniac @Laurawly