[PTX-MMA] Add full PTX MMA code generation support #9909
Conversation
Force-pushed from c97319a to d229a98
Force-pushed from d229a98 to 4da1629
Some minor issues, otherwise LGTM. There are some (unrelated) test errors on CI; could you try pushing again to restart the CI?
src/target/source/ptx_mma.h (Outdated)

namespace tvm {
namespace codegen {

std::string PrintPTXAssembly(const std::string& shape, const std::string& A_layout,
maybe PrintMMAAssembly would be a better name
golden = np.matmul(A_np.astype("float64"), B_np.astype("float64").T)

C_numpy = C_tvm.numpy()
from tvm import testing
this is not needed, as tvm.testing is already imported at the beginning
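(For reference, a minimal, self-contained sketch of the cleaned-up check, with `tvm.testing` imported once at module level. The shapes mirror the test, but the device result is faked here so the snippet runs without a GPU; treat it as an illustration rather than the actual test code.)

```python
import numpy as np
import tvm.testing  # imported once at module level; no re-import inside the test body

A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
B_np = np.random.uniform(-1, 1, [8, 8]).astype("float16")
golden = np.matmul(A_np.astype("float64"), B_np.astype("float64").T)

# In the real test, C_numpy comes from the device buffer (C_tvm.numpy());
# here we reuse the golden result so the sketch runs on its own.
C_numpy = golden.astype("float32")
tvm.testing.assert_allclose(C_numpy, golden, atol=1e-3, rtol=1e-3)
```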
src/target/source/ptx_mma.cc (Outdated)

/*
 * TODO: add mma.m16n8k128
 */
return "";
Suggested change:
-  return "";
+  ICHECK(0);
+  throw;

If this is unreachable, just raise an error.
Force-pushed from 4da1629 to 2664115
for i in range(4):
    Accum[i] = T.float32(0)

for mma_multi_a_col in T.vectorized(4):
Thanks for the PR! I wonder if you could elaborate more on the necessity of the declarations of the MultiA, MultiB, and Accum buffers here. Do buffers like A, B, and C not work within the MMA assembly code generated below?
To use MMA instructions, the multiplicands and accumulator should be placed in registers; otherwise, the behavior is undefined. I have tried to use global buffers (e.g., A, B, C) to invoke MMA instructions, and the results are all wrong.
MultiA[mma_multi_a_col] = A[
    (tx % 32) // 4 + mma_multi_a_col // 2 * 8, (tx % 32) % 4 * 2 + mma_multi_a_col % 2
]
for mma_multi_b_col in T.vectorized(4):
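(Aside: for readers puzzling over the index arithmetic in the MultiA load above, the small numpy check below, my own illustration rather than part of the PR, verifies that the per-thread mapping covers every element of the 16x8 A tile exactly once.)

```python
import numpy as np

# Each of the 32 lanes (tx % 32) owns 4 half-precision elements of the 16x8 A tile.
# Reproduce the (row, col) mapping from the MultiA load and check full coverage.
covered = np.zeros((16, 8), dtype=np.int32)
for tx in range(32):
    for mma_multi_a_col in range(4):
        row = tx % 32 // 4 + mma_multi_a_col // 2 * 8
        col = tx % 32 % 4 * 2 + mma_multi_a_col % 2
        covered[row, col] += 1

assert (covered == 1).all()  # every element is loaded by exactly one lane
```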
Maybe we can combine the three loops to initialize MultiA, MultiB, and possibly Accum, given that the loop invariants are the same.
Maybe it is clearer to keep them separate: people who are not familiar with CUDA or MMA can tell that the loads of MultiA and MultiB and the initialization of Accum are decoupled, which is also in accord with the pattern of the code generated by TVM.

"fp16",
"fp32",
MultiA,
0,
Does the use of MultiA, MultiB, and Accum make the bias/offset here unnecessary?
Maybe there is another way to implement the interface. I followed the existing manner of tvm_mma_sync.
Generally it is necessary if the buffer is larger than required by the mma.
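(To make that concrete: the offset only matters when a single thread-local buffer packs more than one fragment. A hypothetical plain-numpy illustration, not TVM code:)

```python
import numpy as np

# Hypothetical example: one thread-local buffer packing two 4-element A fragments.
# The bias/offset passed to the intrinsic selects which fragment a given mma reads.
local_buf = np.arange(8, dtype=np.float16)  # registers for two fragments
offset_frag0, offset_frag1 = 0, 4           # offsets into the same buffer
frag0 = local_buf[offset_frag0:offset_frag0 + 4]
frag1 = local_buf[offset_frag1:offset_frag1 + 4]
print(frag0, frag1)  # two mma calls would read disjoint halves of local_buf
```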
A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
B_np = np.random.uniform(-1, 1, [8, 8]).astype("float16")
C_np = np.random.uniform(-1, 1, [16, 8]).astype("float32")
Shouldn't the value of C_np be zeros?
It's OK to set C_np to random values, although the most standard way is to set C_np to zeros. The results are not affected by the initial value of C_np because the accumulators are always initialized to zeros.
Yeah, I am worried that the implication may confuse people.
I have changed it to np.zeros.
Note that random initialization does follow the convention we have in the TVM repo, so I don't think it confuses anybody. Changing to zeros should be good too, so I don't have a strong opinion.
Force-pushed from 2664115 to 54c5ca2
Some tests are failing (probably not relevant to this PR). Retriggering.
Failed again. @KnowingNothing would you mind checking the unit tests on your side as well?
It also failed on my local machine.

$ pytest tests/python/frontend/pytorch/qnn_test.py::test_serialized_modules
enabled targets: llvm; llvm -device=arm_cpu; cuda; cuda -model=unknown -libs=cudnn; nvptx; opencl; opencl -device=mali,aocl_sw_emu; opencl -device=intel_graphics
pytest marker:
============================================================== test session starts ===============================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/zchno/TVM/tvm-mirror-pr
collected 1 item
tests/python/frontend/pytorch/qnn_test.py Fatal Python error: Aborted
Current thread 0x00007f149c7e1740 (most recent call first):
File "/home/zchno/venv/prime/lib/python3.8/site-packages/torch/jit/_serialization.py", line 161 in load
File "/home/zchno/TVM/tvm-mirror-pr/tests/python/frontend/pytorch/qnn_test.py", line 513 in test_serialized_modules
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/python.py", line 183 in pytest_pyfunc_call
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/python.py", line 1641 in runtest
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 162 in pytest_runtest_call
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 255 in <lambda>
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 311 in from_call
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 254 in call_runtest_hook
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 215 in call_and_report
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 126 in runtestprotocol
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 348 in pytest_runtestloop
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 323 in _main
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 269 in wrap_session
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 316 in pytest_cmdline_main
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/config/__init__.py", line 162 in main
File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/config/__init__.py", line 185 in console_main
File "/home/zchno/venv/prime/bin/pytest", line 8 in <module>
Aborted (core dumped)
@KnowingNothing Can you try rebasing and testing again?
Force-pushed from 54c5ca2 to e568cc0
I checked the unit test and confirmed it is caused by this commit. It seems the error only happens when …
@vinx13 My experience with std::regex is overwhelmingly negative. If it's the source of these bugs, let's consider other alternatives.
CC: @jinhongyii |
Force-pushed from e568cc0 to 611a7ec
I tried to replace …
Thanks! This is huge.
…y to warp memory (#10855)

We already have PTX mma and mma.sp builtin support in #9909 and #10339. However, we have not yet supported the corresponding data movement builtins for these mma instructions, so the data movement would not be as fast as wmma. This PR brings the `ldmatrix` builtin, which is a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), and we can use it to load several (1/2/4) 8x8 matrices from shared memory to warp memory.
This change adds PTX MMA code generation support, covering most (though not all) MMA variants, for three generations of Tensor Cores: Volta, Turing, and Ampere. The generation logic is mainly implemented in ptx_mma.cc and should have no major influence on existing code. A test file is also provided in tests/python/unittest/test_tir_ptx_mma.py. Here is a list of limitations where further improvement is possible:

- mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc is used for uint1 and mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc for int1. This may not be a perfect decision.
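For readers unfamiliar with the 1-bit variants, here is a rough numpy sketch (my own illustration, not from the PR) of the two accumulation modes named above; `A_bits` and `B_bits` are hypothetical 1-bit operands of an m16n8k256 tile.

```python
import numpy as np

# and.popc accumulates popcount(a & b) per output element (the uint1 path);
# xor.popc accumulates popcount(a ^ b) per output element (the int1 path).
rng = np.random.default_rng(0)
A_bits = rng.integers(0, 2, size=(16, 256), dtype=np.uint8)  # 1-bit A operand
B_bits = rng.integers(0, 2, size=(8, 256), dtype=np.uint8)   # 1-bit B operand

and_popc = (A_bits[:, None, :] & B_bits[None, :, :]).sum(axis=-1).astype(np.int32)
xor_popc = (A_bits[:, None, :] ^ B_bits[None, :, :]).sum(axis=-1).astype(np.int32)
print(and_popc.shape, xor_popc.shape)  # both are the (16, 8) s32 accumulator tile
```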