From f72360df5f951149f5aa6d9e8a37d1712585955f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 06:20:32 +0900 Subject: [PATCH 1/8] register ptx builtin for async copy --- include/tvm/tir/builtin.h | 4 ++++ src/tir/op/builtin.cc | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index b166b16b7721..5a166b2080e4 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -632,6 +632,10 @@ TVM_DLL const Op& ptx_mma_sp(); */ TVM_DLL const Op& ptx_ldmatrix(); +TVM_DLL const Op& ptx_cp_async(); +TVM_DLL const Op& ptx_commit_group(); +TVM_DLL const Op& ptx_wait_group(); + // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 4e8d83dd32df..0415d1bbec9e 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -247,6 +247,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp) TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); From 92edef5a18f3f7bb6fb1fb6ae1f156d548969123 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 06:31:41 +0900 Subject: [PATCH 2/8] add basic codegen --- src/target/source/codegen_cuda.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index d4ec536fb001..3e19d23e5647 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ 
-821,6 +821,17 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string smem_elem_offset = this->PrintExpr(op->args[6]); this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, smem_ptr, smem_elem_offset); + } else if (op->op.same_as(builtin::ptx_cp_async())) { + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + std::string size = this->PrintExpr(op->args[2]); + this->stream << "__asm__ __volatile__(\"cp.async.ca.shared.global [" + dst + "], [" + src + + "], " + size + "\");\n"; + } else if (op->op.same_as(builtin::ptx_commit_group())) { + this->stream << "__asm__ __volatile__(\"cp.async.commit_group\");\n"; + } else if (op->op.same_as(builtin::ptx_wait_group())) { + std::string N = this->PrintExpr(op->args[0]); + this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + "\");\n"; } else { CodeGenC::VisitExpr_(op, os); } From 915991e3895cbab180ddfa745cad60a53c3b6d47 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 08:12:01 +0900 Subject: [PATCH 3/8] add test --- src/target/source/codegen_cuda.cc | 9 +-- src/target/source/ptx.cc | 18 ++++++ src/target/source/ptx.h | 5 ++ .../python/unittest/test_tir_ptx_cp_async.py | 57 +++++++++++++++++++ 4 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 tests/python/unittest/test_tir_ptx_cp_async.py diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 3e19d23e5647..cb78b37dd71c 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -823,10 +823,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { smem_ptr, smem_elem_offset); } else if (op->op.same_as(builtin::ptx_cp_async())) { std::string dst = this->PrintExpr(op->args[0]); - std::string src = this->PrintExpr(op->args[1]); - std::string size = this->PrintExpr(op->args[2]); - this->stream << "__asm__ 
__volatile__(\"cp.async.ca.shared.global [" + dst + "], [" + src + - "], " + size + "\");\n"; + std::string dst_offset = this->PrintExpr(op->args[1]); + std::string src = this->PrintExpr(op->args[2]); + std::string src_offset = this->PrintExpr(op->args[3]); + std::string size = this->PrintExpr(op->args[4]); + this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group\");\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 02a98ffbbabd..84ee75c9e618 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -638,5 +638,23 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type return asm_code; } +std::string PrintCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, size_t bytes) { + + // unsigned int addr; + // __asm__ __volatile__( + // "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + // : "=r"(addr) + // : "l"((void *)({smem_addr})) + // ); + + // this->stream << "__asm__ __volatile__(\"cp.async.ca.shared.global [" + dst + dst_offset + + // "], [" + src + src_offset + "], " + size + "\");\n"; + + return ""; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index c4255d737ad0..654d1477a700 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -79,6 +79,11 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type const std::string& smem_ptr, const std::string& smem_elem_offset); +std::string PrintCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, size_t bytes); + } // 
namespace codegen } // namespace tvm diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py new file mode 100644 index 000000000000..60aba7f3986e --- /dev/null +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing + + +@T.prim_func +def ptx_cp_async( + A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(16, 128), "float16"] +) -> None: + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + bx = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(bx, 1) + T.launch_thread(tx, 32) + with T.block(): + A_shared = T.alloc_buffer([32, 128], "float16", scope="shared") + + for i in range(16): + T.evaluate( + T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16") + ) + + T.ptx_wait_group(0) + + for i in range(128): + B[tx, i] = A_shared[tx, i] + + +@tvm.testing.requires_cuda +def test_ptx_cp_async(): + f = ptx_cp_async + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major * 10 + minor < 80: + # Require at least SM80 + return + +if __name__ == "__main__": + test_ptx_cp_async() From 0ab57fcf1f4095bdf1e95cb8429eb195f73ef3e9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 08:25:11 +0900 Subject: [PATCH 4/8] update codegen --- src/target/source/ptx.cc | 34 +++++++++++++++++++++------------- src/target/source/ptx.h | 2 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 84ee75c9e618..405ee68aac8c 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -641,19 +641,27 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type std::string PrintCpAsyncAssembly(const std::string& shared_ptr, const std::string& shared_elem_offset, const std::string& global_ptr, - const std::string& global_elem_offset, size_t bytes) { - - // unsigned int addr; - // __asm__ __volatile__( - // "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - // : "=r"(addr) - // : "l"((void *)({smem_addr})) - // ); - - // 
this->stream << "__asm__ __volatile__(\"cp.async.ca.shared.global [" + dst + dst_offset + - // "], [" + src + src_offset + "], " + size + "\");\n"; - - return ""; + const std::string& global_elem_offset, const std::string& bytes) { + std::string asm_code = R"( + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)({smem_addr})) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%1], [%2], %3;\n" + :"r"(addr), "l"(global_ptr), "n"(bytes) + ); + } +)"; + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes} ", bytes); + asm_code = replacer.rewrite(asm_code); + return asm_code; } } // namespace codegen diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index 654d1477a700..fd131822ff9f 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -82,7 +82,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type std::string PrintCpAsyncAssembly(const std::string& shared_ptr, const std::string& shared_elem_offset, const std::string& global_ptr, - const std::string& global_elem_offset, size_t bytes); + const std::string& global_elem_offset, const std::string& bytes); } // namespace codegen } // namespace tvm From 405d13673dd4774b5cdc82e7bea979d806929433 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 09:16:50 +0900 Subject: [PATCH 5/8] wip --- src/target/source/codegen_cuda.cc | 4 ++-- src/target/source/ptx.cc | 7 ++++--- .../python/unittest/test_tir_ptx_cp_async.py | 20 +++++++++++++++---- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index cb78b37dd71c..7459d4c250ba 100644 --- a/src/target/source/codegen_cuda.cc +++ 
b/src/target/source/codegen_cuda.cc @@ -829,10 +829,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string size = this->PrintExpr(op->args[4]); this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); } else if (op->op.same_as(builtin::ptx_commit_group())) { - this->stream << "__asm__ __volatile__(\"cp.async.commit_group\");\n"; + this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { std::string N = this->PrintExpr(op->args[0]); - this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + "\");\n"; + this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n"; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 405ee68aac8c..ca2a4b75da4e 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -651,15 +651,16 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, : "l"((void *)({smem_addr})) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%1], [%2], %3;\n" - :"r"(addr), "l"(global_ptr), "n"(bytes) + "cp.async.ca.shared.global [%0], [%1], 16;" + : "=r"(addr) + : "l"((void*)({global_ptr})) ); } )"; Replacer replacer; replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); - replacer.register_rule("{bytes} ", bytes); + // replacer.register_rule("{bytes} ", bytes); asm_code = replacer.rewrite(asm_code); return asm_code; } diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py index 60aba7f3986e..efbd3696978e 100644 --- a/tests/python/unittest/test_tir_ptx_cp_async.py +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -23,7 +23,7 @@ @T.prim_func def ptx_cp_async( - A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(16, 128), "float16"] + A: T.Buffer[(32, 128), 
"float16"], B: T.Buffer[(32, 128), "float16"] ) -> None: T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) bx = T.env_thread("blockIdx.x") @@ -32,13 +32,15 @@ def ptx_cp_async( T.launch_thread(tx, 32) with T.block(): A_shared = T.alloc_buffer([32, 128], "float16", scope="shared") + T.reads(A[0:32, 0:128]) + T.writes(B[0:32, 0:128]) for i in range(16): T.evaluate( T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16") ) - T.ptx_wait_group(0) + T.evaluate(T.ptx_wait_group(0, dtype="float16")) for i in range(128): B[tx, i] = A_shared[tx, i] @@ -48,10 +50,20 @@ def ptx_cp_async( def test_ptx_cp_async(): f = ptx_cp_async arch = tvm.contrib.nvcc.get_target_compute_version() - major, minor = tvm.contrib.nvcc.parse_compute_version(arch) - if major * 10 + minor < 80: + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: # Require at least SM80 return + mod = tvm.build(f, target="cuda") + A_np = np.random.rand(32, 128).astype("float16") + B_np = np.zeros((32, 128)).astype("float16") + dev = tvm.cuda(0) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + if __name__ == "__main__": test_ptx_cp_async() From baaf4b8e22a197d87be3c276b61fe3b706c9422e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 17:03:02 +0900 Subject: [PATCH 6/8] codegen bug fixed, test working --- src/target/source/ptx.cc | 7 +++---- tests/python/unittest/test_tir_ptx_cp_async.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index ca2a4b75da4e..71c68baed6dc 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -651,16 +651,15 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, : "l"((void *)({smem_addr})) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], 16;" - : "=r"(addr) - : 
"l"((void*)({global_ptr})) + "cp.async.cg.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) ); } )"; Replacer replacer; replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); - // replacer.register_rule("{bytes} ", bytes); + replacer.register_rule("{bytes}", bytes); asm_code = replacer.rewrite(asm_code); return asm_code; } diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py index efbd3696978e..d3b89d56b8d0 100644 --- a/tests/python/unittest/test_tir_ptx_cp_async.py +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -38,7 +38,7 @@ def ptx_cp_async( for i in range(16): T.evaluate( T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16") - ) + ) T.evaluate(T.ptx_wait_group(0, dtype="float16")) From 8aa591e90ca5abd535dd28d1b293d92df5d6c1f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 17:05:51 +0900 Subject: [PATCH 7/8] add commit group --- tests/python/unittest/test_tir_ptx_cp_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py index d3b89d56b8d0..2d46d9fb3090 100644 --- a/tests/python/unittest/test_tir_ptx_cp_async.py +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -40,6 +40,7 @@ def ptx_cp_async( T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16") ) + T.evaluate(T.ptx_commit_group(dtype="float16")) T.evaluate(T.ptx_wait_group(0, dtype="float16")) for i in range(128): From 0c704d17f8bbff6b7c819df66504b8e37e41b90e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 17:13:26 +0900 Subject: [PATCH 8/8] add doc --- include/tvm/tir/builtin.h | 15 +++++++++++++++ src/target/source/ptx.h | 8 ++++++++ 
tests/python/unittest/test_tir_ptx_cp_async.py | 10 +++++----- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 5a166b2080e4..f33432645cc3 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -632,7 +632,22 @@ TVM_DLL const Op& ptx_mma_sp(); */ TVM_DLL const Op& ptx_ldmatrix(); +/*! + * \brief tvm intrinsics for ptx async copy from global to shared memory + * + * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t + * bytes); + * + */ TVM_DLL const Op& ptx_cp_async(); + +/*! + * \brief tvm intrinsics for ptx async copy commit and wait. + * + * void ptx_commit_group(); + * void ptx_wait_group(int num); + * + */ TVM_DLL const Op& ptx_commit_group(); TVM_DLL const Op& ptx_wait_group(); diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index fd131822ff9f..c811a1b9c1d6 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -79,6 +79,14 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type const std::string& smem_ptr, const std::string& smem_elem_offset); +/*! + * \brief Print ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. 
+ */ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, const std::string& shared_elem_offset, const std::string& global_ptr, diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py index 2d46d9fb3090..17b60885509f 100644 --- a/tests/python/unittest/test_tir_ptx_cp_async.py +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import tvm from tvm.script import tir as T import numpy as np @@ -22,9 +21,7 @@ @T.prim_func -def ptx_cp_async( - A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"] -) -> None: +def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None: T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") @@ -37,9 +34,12 @@ def ptx_cp_async( for i in range(16): T.evaluate( - T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16") + T.ptx_cp_async( + A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16" ) + ) + # TODO(masahi): Remove dtype requirement from TVMScript parser T.evaluate(T.ptx_commit_group(dtype="float16")) T.evaluate(T.ptx_wait_group(0, dtype="float16"))