Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUTLASS] Add parallel split-k support to wgrad #10185

Merged
merged 4 commits into from
Feb 8, 2022

Conversation

masahi
Copy link
Member

@masahi masahi commented Feb 8, 2022

Building on #10177, this adds parallel split-k support to wgrad.

@comaniac @Laurawly @junrushao1994 @vinx13 @YuchenJin @hwu36 @manishucsd

Split-k is described in https://github.com/NVIDIA/cutlass/blob/master/media/docs/efficient_gemm.md#parallelized-reductions.
This is my first experience using split-k in cutlass or any other API. Wgrad is particularly interesting for split-k since the implicit gemm K dimension is really large in wgrad (N * P * Q where P and Q are the output H and W). Without split-k, wgrad on large spatial inputs is extremely slow.

For now, I'm not trying anything smart to pick the split-k parameter, instead we ask users to provide possible candidates. I tuned over [1, 4, 8, 16, 32, 64] below and that already showed excellent performance. The benchmark code is here.

Benchmark result against cuDNN. Note that currently there are non-trivial difference in cuDNN and TVM + cutlass outputs, especially for the larger batch size. I didn't find anything obviously wrong in the generated code and I gave up fixing accuracy difference at some point. Also note that difference is not due to parallel-split-k, even in a normal case the results were different (and actually improved after split-k lol).

The result showed that cutlass winning across the board (Profiler time vs cuDNN columns, but again, the results do not match exactly). However, there is a serious problem when cutlass wgrad + split-k kernels are called from TVM (TVM + CUTLASS column): Split-k requires large workspace, and the space requirement grows linearly with split-k-slices parameter. Right now we naively allocate the workspace on every cutlass kernel call on each run, while for cuDNN we have a simple workspace memory reuse mechanism implemented in (together with a thread local storage)

void ConvEntry::UpdateWorkspace(const size_t wsize) {
if (workspace_size < wsize) {
if (workspace != nullptr) {
CleanWorkspace();
}
workspace_size = wsize;
workspace = cuda_api->AllocWorkspace(device, workspace_size);
}
}
. So during benchmarking, we run compiled TVM + CUTLASS 100 times, and each time we are allocating the same workspace 🤦‍♂️ while cuDNN allocates only once during algo selection outside of the profiling loop. This is causing huge perf penalty in TVM + cutlass results above.

I attempted adding a simple workspace memory management in https://github.com/masahi/tvm/compare/cutlass-split-k...masahi:cutlass-workspace?expand=1, it kind of works in terms of the expected perf improvement. However, I'm getting segfault and other strange issues. I'm a bit confused as to what the right behavior should be for a thread local memory manager in the context of JIT- generated and compiled multiple translation units. Let me know if you have any thoughts on this issue.

Known issues and TODO

  • Accuracy alignment with cuDNN
  • Figure out workspace memory reuse
  • Split-k parameter selection strategy
  • Support split-k in GEMM and other conv2d kind

@masahi masahi changed the title [CUTLASS] Add split-k support to wgrad [CUTLASS] Add parallel split-k support to wgrad Feb 8, 2022
@hwu36
Copy link

hwu36 commented Feb 8, 2022

If you want to investigate accuracy issue, i suggest you compare both cutlass and cudnn with a naive fp64 or fp32 version.

commit 60b73a91b79d644d8c95f682eedaf47a89abba0d
Author: Masahiro Masuda <[email protected]>
Date:   Tue Feb 8 10:43:11 2022 +0900

    pylint

commit ae2e718
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:51:52 2022 +0900

    Add split-k support for wgrad

    commit 43820d5
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 10:07:34 2022 +0900

        fix and add doc

    commit 446a95b
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 09:48:38 2022 +0900

        dw conv2d properly supported for wgrad

    commit adc4e22
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:32:42 2022 +0900

        fix overwriting template

    commit 040eab0
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:06:27 2022 +0900

        black

    commit e5a07c2
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:03:10 2022 +0900

        add reduction in profiler

    commit be89334
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 06:58:03 2022 +0900

        adding split k reduction to conv2d profiler

    commit ae09b0f
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 11:52:59 2022 +0900

        fixed conv2d_backward_weight typerel for dw conv2d

        commit 16fe531
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 12:59:22 2022 +0900

            wip

        commit 2167c25
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 04:22:19 2022 +0900

            fix conv2d type rel for depth wise and grouped conv2d

    commit 14b12e5
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 05:01:03 2022 +0900

        remove split_k.py

    commit b141271
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 04:48:21 2022 +0900

        workaround for invalid split_k_slice

    commit 6e4c7e1
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 02:43:58 2022 +0900

        support split k in profiler

    commit 2eb1cf4
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 02:31:03 2022 +0900

        improvement

    commit 0bce8f3
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 18:20:12 2022 +0900

        fixed for fp16 output

    commit 30df1bd
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 17:50:33 2022 +0900

        fp32 output works

    commit 7a51995
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 14:30:22 2022 +0900

        fix

    commit 4a383e2
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 14:05:24 2022 +0900

        update c++ codegen

    commit 6206e38
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 13:46:05 2022 +0900

        wip

    commit 0ece49b
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 03:05:21 2022 +0900

        wip

    commit 08a6147
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 13:10:21 2022 +0900

        test worked with fp32 output

    commit 084d5c4
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 12:35:18 2022 +0900

        fix compile error for fprop

    commit 31f2543
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 12:18:06 2022 +0900

        compiled

    commit c2098e7
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 11:11:43 2022 +0900

        wip

commit a145850
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:46:16 2022 +0900

    fixed for sm75

commit 6151506
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:32:46 2022 +0900

    all tests work

commit 041c094
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:19:09 2022 +0900

    dw conv2d properly supported for wgrad

commit 2191918
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 09:14:05 2022 +0900

    wgrad tests now work under pytest

commit 78f76df
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 07:31:54 2022 +0900

    run black

commit 0a82149
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 06:12:39 2022 +0900

    [CUTLASS] Add wgrad support (without split-k)
@jroesch
Copy link
Member

jroesch commented Feb 8, 2022

cc @mbs-octoml interesting example for perf work

@junrushao junrushao merged commit 35a7992 into apache:main Feb 8, 2022
@manishucsd
Copy link

manishucsd commented Feb 8, 2022

Hi Masa, This is amazing progress. Some questions on the known issues:

  • Accuracy alignment with cuDNN
    Can you share the size that has accuracy issues. Can you repro the accuracy issue in profiler?
  • Figure out workspace memory reuse
    Both cuDNN and CUTLASS offers similar get_workspace_size(...) API. Thus, I believe this part should be similar.
  • Split-k parameter selection strategy
    As we discussed in an another thread, we run sweeps to find the best split. You can cut down the sweep on k by using a simple analytic model.

@masahi
Copy link
Member Author

masahi commented Feb 8, 2022

Hi Manish,

Can you share the size that has accuracy issues. Can you repro the accuracy issue in profiler?

The benchmark result I linked above show accuracy difference in the last two columns. Most workload have some differences, except for some deeper layers in batch = 8 which showed exact match. It seems deeper layers, those having small spatial size and large channels, have generally less accuracy problems. The differences become much bigger for batch = 256. So it kind of works but not quite, it is very hard to debug. The profiler in cutlass doesn't report any accuracy problem, which is another mystery. It could be TVM's use of cuDNN wgrad having some issues.

Both cuDNN and CUTLASS offers similar get_workspace_size(...) API. Thus, I believe this part should be similar.

The issue is memory reuse across multiple calls. The way we integrate cuDNN and cutlass are significantly different. I tried to apply a similar memory management strategy we use for cuDNN to the JIT-generated cutlass, but as I said above I'm having strange issues.

As we discussed in an another thread, we run sweeps to find the best split. You can cut down the sweep on k by using a simple analytic model.

Yes, I haven't grokked your note in that thread. I just tried a dumb strategy in my benchmark and it already shows good performance. I didn't pursue perf improvement further, since the accuracy problem was more concerning.

@manishucsd
Copy link

On accuracy, floating point additions are not associative. The change the order can change the result. Parallel reduction does change the order of accumulation over GEMM-K (NPQ). Thus, some change between runs is expected. I don't have a guidance on what threshold to set in checking relative error.

I would take Haicheng's suggestions here and follow:

If you want to investigate accuracy issue, i suggest you compare both cutlass and cudnn with a naive fp64 or fp32 version.
Run FP32 wgrad with no split-k and compare both cutlass and cudnn against this golden reference.

CUTLASS profiler uses integer input to initialize tensors and matrices. This is to make the error checking easier. You can also use the CUTLASS profiler approach to make sure there are no functional error, i.e., try the operation on integer input.

@masahi
Copy link
Member Author

masahi commented Feb 8, 2022

Actually, accuracy difference was there even before I added parallel split-k to wgrad. And that the result got closer to cuDNN after adding split-k. So I believe the issue is not in parallel reduction, there is something off elsewhere. I have seen some workload where cuDNN uses cutlass's wgrad and reduction kernel, even in that case there was difference. Probably I should look at how TVM is using cuDNN wgrad first.

I haven't applied fp32 wgrad on large inputs, for small ones we used in the unit test, the result looked good. We also have an option of comparing against TVM native results, which I only looked briefly.

CUTLASS profiler uses integer input to initialize tensors and matrices. This is to make the error checking easier. You can also use the CUTLASS profiler approach to make sure there are no functional error, i.e., try the operation on integer input.

That's very interesting... I didn't know that. I can definitely try, thanks.

ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* [CUTLASS] Add split-k support to wgrad

commit 60b73a91b79d644d8c95f682eedaf47a89abba0d
Author: Masahiro Masuda <[email protected]>
Date:   Tue Feb 8 10:43:11 2022 +0900

    pylint

commit ae2e718
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:51:52 2022 +0900

    Add split-k support for wgrad

    commit 43820d5
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 10:07:34 2022 +0900

        fix and add doc

    commit 446a95b
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 09:48:38 2022 +0900

        dw conv2d properly supported for wgrad

    commit adc4e22
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:32:42 2022 +0900

        fix overwriting template

    commit 040eab0
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:06:27 2022 +0900

        black

    commit e5a07c2
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:03:10 2022 +0900

        add reduction in profiler

    commit be89334
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 06:58:03 2022 +0900

        adding split k reduction to conv2d profiler

    commit ae09b0f
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 11:52:59 2022 +0900

        fixed conv2d_backward_weight typerel for dw conv2d

        commit 16fe531
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 12:59:22 2022 +0900

            wip

        commit 2167c25
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 04:22:19 2022 +0900

            fix conv2d type rel for depth wise and grouped conv2d

    commit 14b12e5
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 05:01:03 2022 +0900

        remove split_k.py

    commit b141271
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 04:48:21 2022 +0900

        workaround for invalid split_k_slice

    commit 6e4c7e1
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 02:43:58 2022 +0900

        support split k in profiler

    commit 2eb1cf4
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 02:31:03 2022 +0900

        improvement

    commit 0bce8f3
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 18:20:12 2022 +0900

        fixed for fp16 output

    commit 30df1bd
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 17:50:33 2022 +0900

        fp32 output works

    commit 7a51995
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 14:30:22 2022 +0900

        fix

    commit 4a383e2
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 14:05:24 2022 +0900

        update c++ codegen

    commit 6206e38
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 13:46:05 2022 +0900

        wip

    commit 0ece49b
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 03:05:21 2022 +0900

        wip

    commit 08a6147
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 13:10:21 2022 +0900

        test worked with fp32 output

    commit 084d5c4
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 12:35:18 2022 +0900

        fix compile error for fprop

    commit 31f2543
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 12:18:06 2022 +0900

        compiled

    commit c2098e7
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 11:11:43 2022 +0900

        wip

commit a145850
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:46:16 2022 +0900

    fixed for sm75

commit 6151506
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:32:46 2022 +0900

    all tests work

commit 041c094
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:19:09 2022 +0900

    dw conv2d properly supported for wgrad

commit 2191918
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 09:14:05 2022 +0900

    wgrad tests now work under pytest

commit 78f76df
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 07:31:54 2022 +0900

    run black

commit 0a82149
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 06:12:39 2022 +0900

    [CUTLASS] Add wgrad support (without split-k)

* pylint

* add more doc

* more doc clarification
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants