
Support half precision sigmoid activation #378

Merged · 5 commits · Dec 22, 2021

Conversation

@masahi (Contributor) commented Dec 12, 2021

Currently, trying to use sigmoid activation with half_t results in a compile error:

error: no instance of overloaded function "cutlass::exp" matches the argument list
           argument types are: (cutlass::half_t)
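
A minimal scalar overload along these lines is enough to make the call resolve — a sketch that computes in fp32 and converts back, not necessarily what this PR ends up doing:

#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cmath>

namespace cutlass {

// Sketch: exp for half_t via fp32; half_t converts explicitly to float.
CUTLASS_HOST_DEVICE
half_t exp(half_t const& h) {
  return half_t(::expf(float(h)));
}

} // namespace cutlass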

@hwu36 (Collaborator) commented Dec 13, 2021

Hi @masahi,

For the half data type, it would be better to use half2 arithmetic operations, which have 2x throughput. You can take a look at the tanh code here and here.
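
For illustration, a generic half2 sigmoid using CUDA fp16 intrinsics (my sketch, not the tanh code linked above); each intrinsic operates on two fp16 lanes at once:

#include <cuda_fp16.h>

// sigmoid(x) = 1 / (1 + exp(-x)) on both lanes; needs sm_53+ for half2 math.
__device__ __half2 sigmoid2(__half2 x) {
  __half2 one = __float2half2_rn(1.0f);
  return __h2div(one, __hadd2(one, h2exp(__hneg2(x))));
}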

@masahi (Contributor, Author) commented Dec 13, 2021

@hwu36 Thanks for the suggestion. I've added an alternative, vectorized implementation using fast_tanh and the identity sigmoid(x) = (tanh(x/2) + 1) / 2.
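
In scalar form the identity looks like this (a sketch; the PR's version is vectorized over the epilogue fragment, and the half_t overload of cutlass::fast_tanh is assumed from the tanh code referenced above):

#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include <cutlass/half.h>

CUTLASS_HOST_DEVICE
cutlass::half_t sigmoid_via_tanh(cutlass::half_t x) {
  cutlass::half_t half(0.5f);
  // sigmoid(x) = (tanh(x / 2) + 1) / 2
  return half * (cutlass::fast_tanh(x * half) + cutlass::half_t(1.0f));
}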

But the faster one shows a nontrivial accuracy difference from the TVM result:

Mismatched elements: 306397 / 524288 (58.4%)
Max absolute difference: 0.000977
Max relative difference: 14.945
 x: array([[[[2.8320e-02, 2.2241e-01, 4.8828e-04, ..., 6.1035e-03,
          3.2446e-01, 8.6133e-01],
         [4.5166e-01, 9.9902e-01, 9.5166e-01, ..., 7.8125e-01,...
 y: array([[[[2.8275e-02, 2.2217e-01, 2.7800e-04, ..., 5.9357e-03,
          3.2446e-01, 8.6133e-01],
         [4.5142e-01, 9.9902e-01, 9.5166e-01, ..., 7.8174e-01,...

So I added an ifdef flag to switch between the two implementations.
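
The switch is along these lines (the macro name follows what appears upstream in cutlass/epilogue/thread/activation.h; treat the exact spelling as an assumption):

#if defined(CUTLASS_USE_TANH_FOR_SIGMOID)
  // faster: sigmoid(x) = (fast_tanh(x / 2) + 1) / 2
#else
  // more accurate: sigmoid(x) = 1 / (1 + exp(-x))
#endif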

@hwu36 (Collaborator) commented Dec 13, 2021

@shangz-ai, is it you who told me that we need to use fp32 to compute sigmoid for accuracy?

@shangz-ai (Contributor) commented

> @shangz-ai, is it you who told me that we need to use fp32 to compute sigmoid for accuracy?

@hwu36 I'm not sure if we talked about this recently. But yes, in BERT, we usually use fp32 in softmax for better accuracy.

@masahi (Contributor, Author) commented Dec 13, 2021

> @hwu36 I'm not sure if we talked about this recently. But yes, in BERT, we usually use fp32 in softmax for better accuracy.

To be clear, this PR is about fp16 sigmoid, not softmax. I think doing sigmoid in fp16 is generally ok. I saw the accuracy issue only when using fast_tanh_op for sigmoid.

@masahi (Contributor, Author) commented Dec 14, 2021

@hwu36 Can we get this in? I have an upcoming PR to TVM that could make use of this (sigmoid activation fusion with fp16 accumulation).

@hwu36 (Collaborator) commented Dec 14, 2021

I will try to get it in this week. I need to go through our internal test process and I may need to make some changes to your PR directly.

@hwu36 (Collaborator) commented Dec 20, 2021

I am working on this one now. Most of the code has been rewritten. You can wait a little bit before I commit.

@hwu36 (Collaborator) commented Dec 20, 2021

I just committed my change. It basically moves the exp math to fast_math.h. Would you please try it out?

I confess that I haven't done any correctness or performance testing of my code.
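
As I read it, the change amounts to a half_t overload of the fast exp path in fast_math.h; a sketch under that assumption (hexp is CUDA's native fp16 exp, sm_53+):

#include <cuda_fp16.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>

CUTLASS_DEVICE
cutlass::half_t fast_exp(cutlass::half_t const& x) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  return cutlass::half_t(::hexp(x.to_half()));  // native fp16 exp
#else
  return cutlass::half_t(::expf(float(x)));     // fp32 fallback
#endif
}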

@masahi (Contributor, Author) commented Dec 20, 2021

There seem to be some accuracy issues. On a simple test case like d_shape = (16, 16, 32, 32), w_shape = (32, 16, 3, 3) (NCHW), the accuracy is good (I can assert equality with the TVM result at rtol, atol = 1e-5, while the fast_tanh variant requires 1e-3). But on a larger instance, d_shape = (16, 256, 64, 64), w_shape = (256, 256, 3, 3), it fails with

Mismatched elements: 4533726 / 16777216 (27%)
Max absolute difference: 0.0608
Max relative difference: 1.
 x: array([[[[1.0000e+00, 2.2375e-01, 9.9121e-01, ..., 5.4777e-05,
          7.9224e-02, 9.3115e-01],
         [6.4795e-01, 1.0000e+00, 1.2793e-01, ..., 1.0000e+00,...
 y: array([[[[1.0000e+00, 2.2656e-01, 9.9121e-01, ..., 5.4359e-05,
          8.0627e-02, 9.2920e-01],
         [6.4258e-01, 1.0000e+00, 1.3562e-01, ..., 1.0000e+00,...

And on an end-to-end ImageNet model, the output is very different from the ones using either of my previous implementations (scalar and vector).

@masahi (Contributor, Author) commented Dec 20, 2021

Sorry, on the d_shape = (16, 256, 64, 64), w_shape = (256, 256, 3, 3) case, my previous scalar implementation also seems to have the same accuracy difference from the TVM reference. But the fast_exp results on an ImageNet model look very different if you compare the three versions below. The first two use my previous implementations; they mostly agree.

@hwu36 (Collaborator) commented Dec 21, 2021

There must be something wrong, and we need to get to the bottom of it. We need to dump the value after every math step to check. Can you tell me which input to sigmoid generates the most ridiculous output?

To debug, we can include this file and call dump_fragment to check. This example lists some usage. You can add if (blockIdx.x == ... && blockIdx.y == ... && threadIdx.x < ...) to reduce the output size. Or you can add printf anywhere to dump anything you want.

Do you want to give it a try during your daytime, or I can try it tomorrow?
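
A generic form of that guard (in and out are hypothetical placeholders for whatever intermediate values you want to inspect):

// Restrict the dump to one block and a few threads so the log stays readable.
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x < 4) {
  printf("tid %u: in=%f out=%f\n", threadIdx.x, float(in), float(out));
}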

@masahi (Contributor, Author) commented Dec 21, 2021

It seems the .raw() -> .to_half() change fixed the accuracy issue with ImageNet. Does that make sense?

Performance on the EfficientNet V2 model improved from 8.26 to 8.05 ms.
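
My reading of why this fixed it (an inference, not stated in the thread): half_t::raw() returns the underlying uint16_t bit pattern, so passing it where a __half value is expected value-converts an integer, while to_half() hands over the same bits correctly typed:

cutlass::half_t x(1.5f);
uint16_t bits = x.raw();      // storage bits 0x3E00, i.e. the integer 15872
__half   h    = x.to_half();  // same bits, typed as __half: the value 1.5
// ::hexp(h) computes exp(1.5); feeding `bits` through a __half conversion
// would instead compute exp of the bit pattern's integer value.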

@hwu36 (Collaborator) commented Dec 21, 2021

> It seems the .raw() -> .to_half() change fixed the accuracy issue with ImageNet. Does that make sense?

Usually we don't hit the case where the size of the Array is odd. What is your output channel number K? What is the value of Count here?

@masahi (Contributor, Author) commented Dec 21, 2021

This is an ImageNet model, so there are lots of Ks. It seems all of them are even numbers.

> What is the value of Count here?

It depends on the number of channels: 8, 4, 2, or 1.
https://github.com/apache/tvm/blob/aa86dc030ccceb775c1b2955822af14f2544ebc9/python/tvm/contrib/cutlass/conv2d_operation.py#L199-L202

@hwu36 (Collaborator) commented Dec 21, 2021

> It seems all of them are even numbers.

Then, you are not supposed to hit the case where (N % 2 == 1).

Can you dump all the values of Count and the corresponding K that you used?

@masahi (Contributor, Author) commented Dec 21, 2021

N there is not the output channel but the vector width of the epilogue, right? As I said, this could be 8, 4, 2, or 1, and I do have one layer that requires align1 and uses the SiLU activation (which in turn uses sigmoid):

cutlass::epilogue::thread::LinearCombinationSilu<
  cutlass::half_t,
  1,                // Count: epilogue vector width (alignment)
  cutlass::half_t,
  cutlass::half_t
>,

So that code path is indeed used once.

@hwu36 (Collaborator) commented Dec 21, 2021

> N there is not the output channel but the vector width of the epilogue, right?

Correct.

What is the output channel K when you use align1?

@masahi (Contributor, Author) commented Dec 21, 2021

> > N there is not the output channel but the vector width of the epilogue, right?
>
> Correct.
>
> What is the output channel K when you use align1?

K = 24. It is the first layer (C = 3).

@hwu36 (Collaborator) commented Dec 21, 2021

> It is the first layer (C = 3).

The regular conv that CUTLASS has now is functional for small input channels such as C = 3, but not great in performance.

> K = 24

If the output channel is 24, you can set Count to 8. The output alignment is not necessarily the same as the input alignment. In your case, the alignment of A and B needs to be 1 since C = 3, but the output alignment can be 8 since K = 24.
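
Concretely, the instantiation quoted earlier could keep a wide epilogue even with align1 inputs; the same snippet with Count = 8:

cutlass::epilogue::thread::LinearCombinationSilu<
  cutlass::half_t,
  8,                // epilogue Count follows the output: K = 24 is divisible by 8
  cutlass::half_t,
  cutlass::half_t
>,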

@hwu36 (Collaborator) commented Dec 22, 2021

Do we still have an accuracy issue? Is the PR ready to be merged?

@masahi (Contributor, Author) commented Dec 22, 2021

> Do we still have an accuracy issue? Is the PR ready to be merged?

No, the accuracy issue has been resolved and it is ready to go! Thanks.

hwu36 merged commit dceabd4 into NVIDIA:master on Dec 22, 2021
@hwu36 (Collaborator) commented Dec 22, 2021

Do you have new perf numbers after the vectorized sigmoid based on vectorized exp? Do you need to change IsHeavy?

@masahi (Contributor, Author) commented Dec 22, 2021

Performance on the EfficientNet V2 model improved from 8.26 to 8.05 ms using the vectorized exp. No need to change IsHeavy (it is still true).
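
For readers outside the thread, IsHeavy refers to the kIsHeavy flag CUTLASS activation functors expose to mark expensive math; a sketch of the shape, not the upstream definition:

struct Sigmoid {
  // Flags the activation as compute-heavy so the epilogue can choose a
  // schedule suited to expensive math; per the discussion it stays true
  // even with the vectorized exp.
  static bool const kIsHeavy = true;

  CUTLASS_HOST_DEVICE
  cutlass::half_t operator()(cutlass::half_t x) const {
    cutlass::half_t one(1.0f);
    return one / (one + cutlass::fast_exp(-x));
  }
};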
