[CUTLASS] Initial support for conv2d wgrad #10177

Merged: 8 commits merged into apache:main on Feb 8, 2022

Conversation

masahi (Member) commented on Feb 7, 2022:

Add support for offloading the conv2d_backward_weight op to CUTLASS. Note that since in wgrad the K dimension of the implicit GEMM is very large (N * P * Q), split-k is required for reasonable performance. I've already implemented split-k support as well, but I'm sending basic wgrad support first. This includes layout conversion for conv2d_backward_weight, pattern matching, codegen, and test cases.
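
As a rough illustration of why split-k matters here (my own sketch with made-up shapes, not code from this PR), the implicit GEMM reduction extent for wgrad is N * P * Q, which quickly dwarfs the other two GEMM dimensions:

# Hypothetical conv2d shapes, only to show the implicit GEMM extents for wgrad.
N, H, W, C = 32, 56, 56, 64   # input: batch, height, width, channels
K_out, R, S = 64, 3, 3        # filter: output channels, kernel height, kernel width
P, Q = 56, 56                 # conv output height/width (stride 1, padding 1)

# Implicit GEMM view of wgrad: the weight gradient is a (K_out x C*R*S) matrix
# accumulated over N * P * Q terms.
gemm_m = K_out                # 64
gemm_n = C * R * S            # 576
gemm_k = N * P * Q            # 100352 -> the reduction dominates, hence split-k
print(gemm_m, gemm_n, gemm_k)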

The key diff in this PR is the introduction of an accumulation data type that is separate from the output type. Since wgrad requires fp32 accumulation in practice, we always accumulate in fp32 even if the output is fp16. Note that cuDNN also supports only fp32 accumulation + fp16 output for wgrad. For GEMM and the other conv2d ops, the accumulation type is set to the output dtype for now.
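
For intuition only (again my own illustration, not part of the PR), the difference between accumulating in fp16 and accumulating in fp32 with an fp16 output can be reproduced with numpy:

import numpy as np

x = np.random.randn(100352).astype("float16")   # many small terms, as in wgrad
acc_fp16 = np.add.reduce(x, dtype=np.float16)   # fp16 accumulation
acc_fp32 = np.add.reduce(x, dtype=np.float32)   # fp32 accumulation
out_fp16 = acc_fp32.astype("float16")           # fp32 accum -> fp16 output

# fp32 accumulation followed by a single cast typically stays much closer to a
# fp64 reference than accumulating directly in fp16.
ref = np.add.reduce(x.astype("float64"))
print(abs(float(acc_fp16) - ref), abs(float(out_fp16) - ref))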

Also fixes the bug of conv2d_backward_weight not supporting depthwise conv2d, as reported in fd5915a#commitcomment-64547258.

@comaniac @Laurawly

@@ -252,6 +252,8 @@ def select_op(
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
# Use fp32 accumulation for wgrad to align with cuDNN
accumlator_dtype="float32" if conv_kind == ConvKind.Wgrad else out_dtype,
masahi (Member Author) commented:

Here, we force the accumulation dtype to be fp32 for wgrad.

Contributor commented:

This makes a lot of sense. I'm even wondering whether we should force the accum dtype to be fp32 for all ConvKinds instead of just wgrad.

masahi (Member Author) commented on Feb 7, 2022:

Always using fp32 accumulation for the other conv2d kinds would probably cause a perf regression (cuDNN also allows fp16 accumulation for fprop and dgrad). Ideally we should add accumulation_dtype to Conv2DAttrs to guide that decision; I thought about doing that, but realized I would have to change a lot of TOPI to take it into account.

Also, we need to discuss what the interface for ToMixedPrecision should be if we want to allow changing the accumulation dtype; right now we cannot flexibly change even the output dtype. @AndrewZhaoLuo
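
For context, here is a minimal sketch of how the pass is invoked today (my illustration, not from this thread; shapes are arbitrary). As I understand it, the pass exposes only a single target dtype at the call site, while per-op accumulation and output dtypes come from op-level mixed-precision registrations:

import tvm
from tvm import relay

# A tiny fp32 module to run the pass on (illustrative only).
x = relay.var("x", shape=(8, 64), dtype="float32")
w = relay.var("w", shape=(128, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

# Single knob here; accumulation/output dtype choices are made per op elsewhere.
mod_fp16 = relay.transform.ToMixedPrecision(mixed_precision_type="float16")(mod)
print(mod_fp16)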

Contributor commented:

I see, makes sense. We discussed the accum dtype before, when @AndrewZhaoLuo was working on the ToMixedPrecision pass, but just as you pointed out, this would involve lots of TOPI changes.

AndrewZhaoLuo (Contributor) commented on Feb 7, 2022:

Hmm, yeah, so if I'm understanding correctly, for conv2d_winograd we want to accumulate in fp32, but if it's not winograd we are ok with accumulating in fp16.

ToMixedPrecision can configure accumulation and output dtypes for any call node, but only using information from examining that node. I'm not sure implementation details like whether it's winograd can be conveyed there.

I will say that at the Relay level all we care about is type checking, imo, so just get the output_dtype correct. For example, accumulate all you like in fp32, but internally make sure the output fits the expected type written in the interface. Perhaps an extraneous cast there is bad, but maybe we can repair it further down at the TOPI/TIR level.

masahi (Member Author) commented on Feb 7, 2022:

Yeah, I agree that we can let out_dtype in conv2d essentially act as the accumulation data type and add an explicit cast op if accum_dtype != out_dtype. That's fine as far as the ToMixedPrecision pass goes, but for the CUTLASS BYOC we would additionally need to pattern match against the added cast(fp32 -> fp16) op to know that this conv2d is fp32 accum -> fp16 out. And for cuDNN, which is not implemented as BYOC, this doesn't work because all it sees is an fp32 accum -> fp32 out conv2d, and cuDNN wgrad doesn't support that dtype combination.

> Hmm yeah, so if I'm understanding correctly for conv2d_winograd we want to accumulate to fp32 but if it's not winograd we are ok with accumulating to fp16.

@AndrewZhaoLuo Here wgrad means "conv2d gradient with respect to weight", not winograd :)
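
To make the cast-matching concern above concrete, here is a minimal sketch (my own illustration, not code from this PR) of the pattern a backend would have to recognize if out_dtype doubled as the accumulation dtype: a conv2d producing fp32 followed by a cast back to fp16.

import tvm
from tvm import relay

data = relay.var("data", shape=(8, 32, 56, 56), dtype="float16")
weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float16")

# conv2d accumulating/outputting in fp32 ...
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1),
                       out_dtype="float32")
# ... followed by an explicit cast back to fp16. A backend wanting
# "fp32 accum -> fp16 out" semantics must match conv2d + cast together.
out = relay.cast(conv, "float16")

mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
print(mod)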

masahi merged commit 7fd73b2 into apache:main on Feb 8, 2022
masahi (Member Author) commented on Feb 8, 2022:

thanks @comaniac

ylc pushed a commit to ylc/tvm that referenced this pull request on Feb 16, 2022, with the following commit message:
* [CUTLASS] Add wgrad support (without split-k)

* run black

* wgrad tests now work under pytest

* dw conv2d properly supported for wgrad

* all tests work

* fixed for sm75

* cpplint

* fix conv2d grad test