Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][DLight] Introduce Specific Rule for RMSNorm #16338

Merged
merged 6 commits into from
Jan 8, 2024
Merged

Conversation

Celve
Copy link
Contributor

@Celve Celve commented Jan 3, 2024

This PR introduces a specific rule for RMS norm in dlight, which allow the norm to now perform on par with CUTLASS standards.

New test case for the rule has also been added.

Below is the performance comparison for Llama-2-7b-chat-hf-q4f16_1.

New:

======================= Encoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
fused_fused_decode4_NT_matmul3                    0.3117      32      9.9759            33.93     48.67           152.4731          (22016, 512), (22016, 128), (1, 6, 4096), (1, 6, 22016)
fused_fused_decode5_fused_NT_matmul4_add          0.2671      32      8.5482            29.08     24.41           89.2261           (4096, 1376), (4096, 344), (1, 6, 11008), (1, 6, 4096), (1, 6, 4096)
fused_fused_decode2_NT_matmul                     0.2037      32      6.5183            22.17     27.19           130.3419          (12288, 512), (12288, 128), (1, 6, 4096), (1, 6, 12288)
fused_fused_decode3_fused_NT_matmul2_add          0.1013      32      3.2422            11.03     9.14            88.1027           (4096, 512), (4096, 128), (1, 6, 4096), (1, 6, 4096), (1, 6, 4096)
fused_NT_matmul1_divide_maximum_minimum_cast      0.0070      32      0.2245            0.76      0.10            13.6735           (1, 32, 6, 128), (1, 32, 6, 128), (1, 1, 6, 6), (1, 32, 6, 6)
rms_norm                                          0.0027      65      0.1752            0.60      0.10            36.7974           (1, 6, 4096), (4096,), (1, 6, 4096)
matmul4                                           0.0049      32      0.1571            0.53      0.10            19.0895           (1, 32, 6, 6), (1, 32, 6, 128), (1, 32, 6, 128)
transpose4                                        0.0022      64      0.1413            0.48      0.09            41.4662           (1, 6, 32, 128), (1, 32, 6, 128)
fused_fused_decode1_fused_NT_matmul5_cast2        0.1066      1       0.1066            0.36      70.44           645.4770          (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
fused_softmax_cast1                               0.0030      32      0.0955            0.32      0.01            2.1578            (1, 32, 6, 6), (1, 32, 6, 6)
transpose5                                        0.0022      32      0.0716            0.24      0.09            40.9420           (1, 32, 6, 128), (1, 6, 32, 128)
fused_transpose4                                  0.0022      32      0.0706            0.24      0.09            41.5244           (1, 6, 32, 128), (1, 32, 6, 128)
fused_split_silu_multiply                         0.0020      32      0.0654            0.22      0.38            180.6251          (1, 6, 22016), (1, 6, 11008)
extend_te                                         0.0022      1       0.0022            0.01      0.00            0.0600            (1, 1, 6, 6), (1, 1, 6, 6)
fused_fused_decode1_take                          0.0019      1       0.0019            0.01      70.36           35502.6806        (32000, 512), (32000, 128), (6,), (6, 4096)
slice                                             0.0019      1       0.0019            0.01      0.05            28.1075           (1, 6, 4096), (1, 1, 4096)
Total time: 29.3983 ms

======================= Decoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
fused_fused_decode4_NT_matmul9                    0.0747      32      2.3906            36.75     48.42           633.0044          (22016, 512), (22016, 128), (1, 1, 4096), (1, 1, 22016)
fused_fused_decode5_fused_NT_matmul10_add1        0.0445      32      1.4238            21.89     24.22           531.6822          (4096, 1376), (4096, 344), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
fused_fused_decode2_NT_matmul6                    0.0437      32      1.3986            21.50     27.03           603.9670          (12288, 512), (12288, 128), (1, 1, 4096), (1, 1, 12288)
fused_fused_decode3_fused_NT_matmul8_add1         0.0184      32      0.5892            9.06      9.02            478.6109          (4096, 512), (4096, 128), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
rms_norm1                                         0.0025      65      0.1618            2.49      0.02            9.1963            (1, 1, 4096), (4096,), (1, 1, 4096)
transpose4                                        0.0022      64      0.1416            2.18      0.11            48.2909           (1, 7, 32, 128), (1, 32, 7, 128)
fused_fused_decode1_fused_NT_matmul5_cast2        0.1066      1       0.1066            1.64      70.44           645.3116          (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
matmul10                                          0.0025      32      0.0786            1.21      0.06            25.0127           (1, 32, 1, 7), (1, 32, 7, 128), (1, 32, 1, 128)
fused_softmax2_cast4                              0.0024      32      0.0752            1.16      0.00            0.5324            (1, 32, 1, 7), (1, 32, 1, 7)
fused_NT_matmul7_divide2_maximum1_minimum1_cast3  0.0023      32      0.0744            1.14      0.06            26.6235           (1, 32, 1, 128), (1, 32, 7, 128), (1, 1, 1, 7), (1, 32, 1, 7)
fused_split1_silu1_multiply1                      0.0019      32      0.0606            0.93      0.06            32.4741           (1, 1, 22016), (1, 1, 11008)
full                                              0.0020      1       0.0020            0.03      0.00            0.0066            (1, 1, 1, 7)
fused_fused_decode1_take1                         0.0018      1       0.0018            0.03      70.32           37132.5438        (32000, 512), (32000, 128), (1,), (1, 4096)
Total time: 6.5048 ms

Old:

======================= Encoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
fused_fused_decode4_NT_matmul3                    0.3116      32      9.9706            33.84     48.67           152.5537          (22016, 512), (22016, 128), (1, 6, 4096), (1, 6, 22016)
fused_fused_decode5_fused_NT_matmul4_add          0.2670      32      8.5452            29.00     24.41           89.2574           (4096, 1376), (4096, 344), (1, 6, 11008), (1, 6, 4096), (1, 6, 4096)
fused_fused_decode2_NT_matmul                     0.2037      32      6.5187            22.13     27.19           130.3332          (12288, 512), (12288, 128), (1, 6, 4096), (1, 6, 12288)
fused_fused_decode3_fused_NT_matmul2_add          0.1013      32      3.2421            11.00     9.14            88.1061           (4096, 512), (4096, 128), (1, 6, 4096), (1, 6, 4096), (1, 6, 4096)
rms_norm                                          0.0038      65      0.2440            0.83      0.10            26.4202           (1, 6, 4096), (4096,), (1, 6, 4096)
fused_NT_matmul1_divide1_maximum_minimum_cast     0.0070      32      0.2248            0.76      0.10            13.6538           (1, 32, 6, 128), (1, 32, 6, 128), (1, 1, 6, 6), (1, 32, 6, 6)
matmul                                            0.0049      32      0.1569            0.53      0.10            19.1045           (1, 32, 6, 6), (1, 32, 6, 128), (1, 32, 6, 128)
transpose                                         0.0022      64      0.1413            0.48      0.09            41.4722           (1, 6, 32, 128), (1, 32, 6, 128)
fused_fused_decode1_fused_NT_matmul5_cast2        0.1068      1       0.1068            0.36      70.44           644.0582          (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
fused_softmax1_cast1                              0.0029      32      0.0918            0.31      0.01            2.2452            (1, 32, 6, 6), (1, 32, 6, 6)
transpose1                                        0.0023      32      0.0721            0.24      0.09            40.6511           (1, 32, 6, 128), (1, 6, 32, 128)
fused_split_silu_multiply                         0.0022      32      0.0711            0.24      0.38            166.0295          (1, 6, 22016), (1, 6, 11008)
fused_transpose                                   0.0022      32      0.0706            0.24      0.09            41.5205           (1, 6, 32, 128), (1, 32, 6, 128)
extend_te                                         0.0022      1       0.0022            0.01      0.00            0.0600            (1, 1, 6, 6), (1, 1, 6, 6)
fused_fused_decode1_take                          0.0020      1       0.0020            0.01      70.36           34944.1585        (32000, 512), (32000, 128), (6,), (6, 4096)
slice                                             0.0019      1       0.0019            0.01      0.05            28.0399           (1, 6, 4096), (1, 1, 4096)
Total time: 29.4621 ms

======================= Decoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
fused_fused_decode4_NT_matmul9                    0.0747      32      2.3900            36.35     48.42           633.1580          (22016, 512), (22016, 128), (1, 1, 4096), (1, 1, 22016)
fused_fused_decode5_fused_NT_matmul10_add1        0.0445      32      1.4231            21.64     24.22           531.9347          (4096, 1376), (4096, 344), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
fused_fused_decode2_NT_matmul6                    0.0437      32      1.3984            21.27     27.03           604.0569          (12288, 512), (12288, 128), (1, 1, 4096), (1, 1, 12288)
fused_fused_decode3_fused_NT_matmul8_add1         0.0184      32      0.5898            8.97      9.02            478.0792          (4096, 512), (4096, 128), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
rms_norm1                                         0.0036      65      0.2338            3.56      0.02            6.3626            (1, 1, 4096), (4096,), (1, 1, 4096)
transpose                                         0.0022      64      0.1414            2.15      0.11            48.3282           (1, 7, 32, 128), (1, 32, 7, 128)
fused_fused_decode1_fused_NT_matmul5_cast2        0.1069      1       0.1069            1.63      70.44           643.6275          (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
matmul3                                           0.0025      32      0.0791            1.20      0.06            24.8680           (1, 32, 1, 7), (1, 32, 7, 128), (1, 32, 1, 128)
fused_softmax2_cast4                              0.0024      32      0.0752            1.14      0.00            0.5324            (1, 32, 1, 7), (1, 32, 1, 7)
fused_NT_matmul7_divide2_maximum1_minimum1_cast3  0.0023      32      0.0743            1.13      0.06            26.6699           (1, 32, 1, 128), (1, 32, 7, 128), (1, 1, 1, 7), (1, 32, 1, 7)
fused_split1_silu1_multiply1                      0.0019      32      0.0596            0.91      0.06            33.0057           (1, 1, 22016), (1, 1, 11008)
fused_fused_decode1_take1                         0.0019      1       0.0019            0.03      70.32           35296.3286        (32000, 512), (32000, 128), (1,), (1, 4096)
full                                              0.0019      1       0.0019            0.03      0.00            0.0067            (1, 1, 1, 7)
Total time: 6.5756 ms

@vinx13
Copy link
Member

vinx13 commented Jan 5, 2024

Does the result in Old refer to dlight or cutlass?

@Celve
Copy link
Contributor Author

Celve commented Jan 6, 2024

Does the result in Old refer to dlight or cutlass?

It refers to dlight.

@vinx13
Copy link
Member

vinx13 commented Jan 6, 2024

thanks for the clarification, do we have numbers comparing with cutlass?

@Celve
Copy link
Contributor Author

Celve commented Jan 6, 2024

thanks for the clarification, do we have numbers comparing with cutlass?

cutlass:

======================= Encoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
fused_fused_decode4_NT_matmul2                    0.3124      32      9.9961            34.48     48.67           152.1651          (22016, 512), (22016, 128), (1, 6, 4096), (1, 6, 22016)
fused_fused_decode5_fused_NT_matmul3_add          0.2669      32      8.5402            29.46     24.41           89.3100           (4096, 1376), (4096, 344), (1, 6, 11008), (1, 6, 4096), (1, 6, 4096)
fused_fused_decode2_NT_matmul                     0.2039      32      6.5247            22.51     27.19           130.2134          (12288, 512), (12288, 128), (1, 6, 4096), (1, 6, 12288)
fused_fused_decode3_fused_NT_matmul1_add          0.1013      32      3.2423            11.18     9.14            88.0998           (4096, 512), (4096, 128), (1, 6, 4096), (1, 6, 4096), (1, 6, 4096)
fused_relax_nn_attention_cutlass1                 0.0057      32      0.1828            0.63      0.20            33.3827           (1, 6, 32, 128), (1, 6, 32, 128), (1, 6, 32, 128), (8192,), (1, 6, 32, 128)
split                                             0.0052      32      0.1665            0.57      0.28            52.7735           (1, 6, 12288), (1, 6, 4096), (1, 6, 4096), (1, 6, 4096)
fused_rms_norm_cutlass                            0.0025      65      0.1644            0.57      0.10            39.2137           (1, 6, 4096), (4096,), (1, 6, 4096)
fused_fused_decode1_fused_NT_matmul4_cast         0.1064      1       0.1064            0.37      70.44           646.5570          (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
fused_split1_silu_multiply                        0.0020      32      0.0626            0.22      0.38            188.7035          (1, 6, 22016), (1, 6, 11008)
fused_fused_decode1_take                          0.0019      1       0.0019            0.01      70.36           35502.7417        (32000, 512), (32000, 128), (6,), (6, 4096)
slice                                             0.0018      1       0.0018            0.01      0.05            29.3516           (1, 6, 4096), (1, 1, 4096)
Total time: 28.9899 ms

======================= Decoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
fused_fused_decode4_NT_matmul7                    0.0747      32      2.3908            37.73     48.42           632.9457          (22016, 512), (22016, 128), (1, 1, 4096), (1, 1, 22016)
fused_fused_decode5_fused_NT_matmul8_add1         0.0454      32      1.4527            22.93     24.22           521.0961          (4096, 1376), (4096, 344), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
fused_fused_decode2_NT_matmul5                    0.0437      32      1.3981            22.06     27.03           604.1956          (12288, 512), (12288, 128), (1, 1, 4096), (1, 1, 12288)
fused_fused_decode3_fused_NT_matmul6_add1         0.0182      32      0.5829            9.20      9.02            483.7281          (4096, 512), (4096, 128), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
fused_relax_nn_attention1_cutlass1                0.0057      32      0.1822            2.88      0.13            22.7810           (1, 1, 32, 128), (1, 7, 32, 128), (1, 7, 32, 128), (8192,), (1, 1, 32, 128)
fused_rms_norm1_cutlass                           0.0024      65      0.1580            2.49      0.02            9.4179            (1, 1, 4096), (4096,), (1, 1, 4096)
fused_fused_decode1_fused_NT_matmul4_cast         0.1066      1       0.1066            1.68      70.44           645.1684          (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
fused_split2_silu1_multiply1                      0.0020      32      0.0631            1.00      0.06            31.1784           (1, 1, 22016), (1, 1, 11008)
fused_fused_decode1_take1                         0.0019      1       0.0019            0.03      70.32           36989.5363        (32000, 512), (32000, 128), (1,), (1, 4096)
Total time: 6.3364 ms

@Hzfengsy Hzfengsy merged commit 8e54a9e into apache:unity Jan 8, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants