From 07592d2be8a4afc08c370de7d5341e6557cb53bf Mon Sep 17 00:00:00 2001 From: adstraw Date: Wed, 31 Aug 2022 09:23:52 -0700 Subject: [PATCH 1/8] [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to HexagonUserDMA --- include/tvm/tir/builtin.h | 10 ++ include/tvm/tir/transform.h | 5 + src/driver/driver_api.cc | 1 + src/runtime/hexagon/hexagon_device_api.cc | 25 +++ src/tir/op/builtin.cc | 6 + src/tir/transforms/lower_async_dma.cc | 144 ++++++++++++++++++ src/tir/transforms/lower_tvm_builtin.cc | 30 ++++ .../test_software_pipeline_async.py | 70 +++++++++ 8 files changed, 291 insertions(+) create mode 100644 src/tir/transforms/lower_async_dma.cc create mode 100644 tests/python/contrib/test_hexagon/test_software_pipeline_async.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 12290a97c840..a1a97595bfd8 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -720,6 +720,16 @@ TVM_DLL const Op& texture2d_load(); */ TVM_DLL const Op& mem_copy(); +/*! + * \brief Initiate a non-blocking DMA copy from source to destination + */ +TVM_DLL const Op& dma_copy(); + +/*! + * \brief Wait until the number of DMAs in flight is less than or equal to some maximum + */ +TVM_DLL const Op& dma_wait(); + /*! * \brief Provide a true statement that can be used for simplifications * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index fd4261e4a4e3..a4caeee43604 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -485,6 +485,11 @@ TVM_DLL Pass TextureFlatten(); */ TVM_DLL Pass LowerVtcmAlloc(); +/*! + * \brief Lower Async TIR primitives to DMA copy and wait builtins + */ +TVM_DLL Pass LowerAsyncDMA(); + /*! * \brief Implements a Common Subexpression Elimination (CSE) for TIR * which introduces let-in bindings for duplicated sub-expressions. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e528686d967d..54b6f59a675f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -225,6 +225,7 @@ Array CreatePassList(bool disable_loop_partition) { } // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations pass_list.push_back(tir::transform::LowerVtcmAlloc()); + pass_list.push_back(tir::transform::LowerAsyncDMA()); pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index fd3a0db2025b..cd881532a49e 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -33,6 +33,7 @@ #include "../workspace_pool.h" #include "hexagon_common.h" +#include "hexagon_user_dma.h" namespace tvm { namespace runtime { @@ -206,6 +207,30 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM *rv = static_cast(0); }); +TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) { + int queue_id = args[0]; + ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA"); + void* dst = args[1]; + void* src = args[2]; + int size = args[3]; + ICHECK(size >= 0); + + int ret = DMA_RETRY; + do { + ret = HexagonUserDMA::Get().Copy(dst, src, size); + } while (ret == DMA_RETRY); + *rv = static_cast(ret); +}); + +TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) { + int queue_id = args[0]; + ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA"); + int inflight = args[1]; + ICHECK(inflight >= 0); + HexagonUserDMA::Get().Wait(inflight); + *rv = static_cast(0); +}); + TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) { int32_t device_type = args[0]; int32_t device_id = args[1]; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 9642f8e39f39..1e2d790c76e1 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -288,6 +288,12 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load) TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(assume) .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) .set_num_inputs(1); diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc new file mode 100644 index 000000000000..a58112fc8195 --- /dev/null +++ b/src/tir/transforms/lower_async_dma.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_async_dma.cc + */ + +#include +#include +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +class AsyncDMALowerer : public StmtExprMutator { + public: + AsyncDMALowerer() {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + // Convert this, for example: + // attr [0] "async_wait_queue_scope" = 0; + // attr [0] "async_wait_inflight_count" = 0; + // + // To this: + // @tir.dma_wait( + // 0, /* queue id */ + // 0, /* in flight count */ + // dtype=int32 + // ) + if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto async_wait = op->body.as(); + ICHECK(async_wait && async_wait->attr_key == tir::attr::async_wait_inflight_count); + + auto call_dma_wait = + Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {op->value, async_wait->value})); + + // concatenate the call with the body and return + return SeqStmt({call_dma_wait, async_wait->body}); + + // Convert this, for example: + // attr [0] "async_commit_queue_scope" = 0; + // attr [0] "async_scope" = 1; + // for (ax0: int32, 0, 128) { + // A_global[ax0] = A[ax0] + // } + // + // To this: + // @tir.dma_copy( + // 0, /* queue id */ + // @tir.address_of(A_global[0], dtype=handle), + // @tir.address_of(A[0], dtype=handle), + // 128, /* size */ + // dtype=int32 + // ) + } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + auto async_scope = op->body.as(); + ICHECK(async_scope && async_scope->attr_key == tir::attr::async_scope); + + auto for_loop = async_scope->body.as(); + if (!for_loop) { + return StmtExprMutator::VisitStmt_(op); + } + + auto bufferstorenode = for_loop->body.as(); + if (!bufferstorenode) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK(bufferstorenode->indices.size() == 1); + + auto bufferloadnode = bufferstorenode->value.as(); + if (!bufferloadnode) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK(bufferloadnode->indices.size() == 1); + + auto bufferstore = bufferstorenode->buffer.as(); + ICHECK(bufferstore && bufferstore->strides.empty()); + + auto bufferload = bufferloadnode->buffer.as(); + ICHECK(bufferload && bufferload->strides.empty()); + + // map loop variable to zero + Map loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}}; + + Array store_indices = bufferstorenode->indices; + store_indices.MutateByApply([&](PrimExpr expr) { + arith::Analyzer analyzer; + return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); + }); + + Array load_indices = bufferloadnode->indices; + load_indices.MutateByApply([&](PrimExpr expr) { + arith::Analyzer analyzer; + return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); + }); + + return Evaluate(Call(DataType::Int(32), builtin::dma_copy(), + {op->value, + Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(bufferstorenode->buffer, store_indices)}), + Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(bufferloadnode->buffer, load_indices)}), + for_loop->extent * bufferloadnode->dtype.bytes()})); + } + return StmtExprMutator::VisitStmt_(op); + } +}; + +namespace transform { + +Pass LowerAsyncDMA() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto fptr = f.CopyOnWrite(); + fptr->body = AsyncDMALowerer()(std::move(fptr->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 9d0087cc7a0b..f79682ef7ecc 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -317,6 +317,10 @@ class BuiltinLower : public StmtExprMutator { return make_zero(op->dtype); } else if (op->op.same_as(builtin::mem_copy())) { return MakeMemCopy(op); + } else if (op->op.same_as(builtin::dma_copy())) { + return MakeDMACopy(op); + } else if (op->op.same_as(builtin::dma_wait())) { + return MakeDMAWait(op); } else { return StmtExprMutator::VisitExpr_(op); } @@ -335,6 +339,32 @@ class BuiltinLower : public StmtExprMutator { return VisitExpr(call_packed); } + PrimExpr MakeDMACopy(const CallNode* op) { + PrimExpr queue_id = op->args[0]; + PrimExpr dst = op->args[1]; + PrimExpr src = op->args[2]; + PrimExpr size = op->args[3]; + + std::string fdevapi_prefix = + "device_api." + std::string(runtime::DeviceName(device_type_.as()->value)); + + Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".dma_copy"), queue_id, dst, src, size}); + return VisitExpr(call_packed); + } + + PrimExpr MakeDMAWait(const CallNode* op) { + PrimExpr queue_id = op->args[0]; + PrimExpr inflight = op->args[1]; + + std::string fdevapi_prefix = + "device_api." + std::string(runtime::DeviceName(device_type_.as()->value)); + + Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".dma_wait"), queue_id, inflight}); + return VisitExpr(call_packed); + } + // call shape PrimExpr MakeShape(const CallNode* op) { // if args.size() == 0, it represents a scalar shape () diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py new file mode 100644 index 000000000000..963220af5670 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import pytest +import numpy as np + +import tvm +from tvm import tir +from tvm.contrib.hexagon.session import Session +from tvm.script import tir as T + +outer = 16 +inner = 128 + + +@T.prim_func +def plus_one_primfunc(A: T.Buffer[(outer, inner), "uint8"], B: T.Buffer[(outer, inner), "uint8"]): + for i in T.serial(outer): + for j in T.serial(inner): + with T.block("plus_one"): + with T.block(): + B[i, j] = A[i, j] + T.uint8(1) + + +@tvm.testing.requires_hexagon +def test_software_pipeline_with_cache_read(hexagon_launcher): + sch = tir.Schedule(plus_one_primfunc) + root = sch.get_block("root") + plus_one = sch.get_block("plus_one") + cache_read_block = sch.cache_read(plus_one, 0, "global") + + i, j = sch.get_loops(plus_one) + sch.compute_at(cache_read_block, i) + sch.annotate(i, "software_pipeline_stage", [0, 1]) + sch.annotate(i, "software_pipeline_order", [0, 1]) + sch.annotate(i, "software_pipeline_async_stages", [0]) + + target_hexagon = tvm.target.hexagon("v68", link_params=True) + func = tvm.build(sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)) + + with hexagon_launcher.start_session() as hexagon_session: + mod = hexagon_session.load_module(func) + dev = hexagon_session.device + + a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype("uint8") + b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype("uint8") + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + mod(a, b) + ref = a_np + 1 + np.testing.assert_equal(b.numpy(), ref) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) From 450b3b32921e59279a48610b2b90dec71af34150 Mon Sep 17 00:00:00 2001 From: adstraw Date: Thu, 15 Sep 2022 09:17:37 -0700 Subject: [PATCH 2/8] save queue ID in `copy`, inspect in `wait` transform; add comments --- src/tir/transforms/lower_async_dma.cc | 58 ++++++++++++++----- .../test_software_pipeline_async.py | 2 + 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index a58112fc8195..7e829b10f596 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -46,11 +46,21 @@ class AsyncDMALowerer : public StmtExprMutator { // dtype=int32 // ) if (op->attr_key == tir::attr::async_wait_queue_scope) { + // get queue ID + auto queue_id_node = op->value.as(); + ICHECK(queue_id_node); + int queue_id = queue_id_node->value; + + // abort if we have not seen this queue ID in `copy` transform + if (queue_ids.find(queue_id) == queue_ids.end()) { + return StmtExprMutator::VisitStmt_(op); + } + auto async_wait = op->body.as(); ICHECK(async_wait && async_wait->attr_key == tir::attr::async_wait_inflight_count); auto call_dma_wait = - Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {op->value, async_wait->value})); + Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value})); // concatenate the call with the body and return return SeqStmt({call_dma_wait, async_wait->body}); @@ -71,59 +81,77 @@ class AsyncDMALowerer : public StmtExprMutator { // dtype=int32 // ) } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + // get queue ID + auto queue_id_node = op->value.as(); + ICHECK(queue_id_node); + int queue_id = queue_id_node->value; + + // save queue ID for inspection in `wait` transform + queue_ids.insert(queue_id); + + // walk the graph to verify this is a mem copy ... + // 1) async_commit_queue_scope contains async_scope auto async_scope = op->body.as(); ICHECK(async_scope && async_scope->attr_key == tir::attr::async_scope); + // 2) async_scope contains single for loop auto for_loop = async_scope->body.as(); if (!for_loop) { return StmtExprMutator::VisitStmt_(op); } + // 3) for loop contains buffer store with single index auto bufferstorenode = for_loop->body.as(); - if (!bufferstorenode) { + if (!bufferstorenode || bufferstorenode->indices.size() != 1) { return StmtExprMutator::VisitStmt_(op); } - ICHECK(bufferstorenode->indices.size() == 1); - + // 4) buffer store value is a buffer load with single index auto bufferloadnode = bufferstorenode->value.as(); - if (!bufferloadnode) { + if (!bufferloadnode || bufferloadnode->indices.size() != 1) { return StmtExprMutator::VisitStmt_(op); } - ICHECK(bufferloadnode->indices.size() == 1); - + // get store buffer and assert that it is contiguous auto bufferstore = bufferstorenode->buffer.as(); ICHECK(bufferstore && bufferstore->strides.empty()); + // get load buffer and assert that it is contiguous auto bufferload = bufferloadnode->buffer.as(); ICHECK(bufferload && bufferload->strides.empty()); - // map loop variable to zero + // we will be replacing the entire for loop including its index + // with a DMA copy instrinsic that spans the entire index space of the for loop + // so we will need to repace the for loop index with value zero in the buffer indices Map loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}}; - Array store_indices = bufferstorenode->indices; - store_indices.MutateByApply([&](PrimExpr expr) { + // map loop variable to zero for the store index & simplify + Array store_index = bufferstorenode->indices; + store_index.MutateByApply([&](PrimExpr expr) { arith::Analyzer analyzer; return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); }); - Array load_indices = bufferloadnode->indices; - load_indices.MutateByApply([&](PrimExpr expr) { + // map loop variable to zero for the load index & simplify + Array load_index = bufferloadnode->indices; + load_index.MutateByApply([&](PrimExpr expr) { arith::Analyzer analyzer; return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); }); return Evaluate(Call(DataType::Int(32), builtin::dma_copy(), - {op->value, + {queue_id, Call(DataType::Handle(), builtin::address_of(), - {BufferLoad(bufferstorenode->buffer, store_indices)}), + {BufferLoad(bufferstorenode->buffer, store_index)}), Call(DataType::Handle(), builtin::address_of(), - {BufferLoad(bufferloadnode->buffer, load_indices)}), + {BufferLoad(bufferloadnode->buffer, load_index)}), for_loop->extent * bufferloadnode->dtype.bytes()})); } return StmtExprMutator::VisitStmt_(op); } + + private: + std::set queue_ids; }; namespace transform { diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 963220af5670..69c1757ba460 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -50,6 +50,8 @@ def test_software_pipeline_with_cache_read(hexagon_launcher): sch.annotate(i, "software_pipeline_order", [0, 1]) sch.annotate(i, "software_pipeline_async_stages", [0]) + tvm.lower(sch.mod["main"]).show() + target_hexagon = tvm.target.hexagon("v68", link_params=True) func = tvm.build(sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)) From b95f1e1c68ef8266bfaeb5503f184aaccb8891c8 Mon Sep 17 00:00:00 2001 From: adstraw Date: Fri, 16 Sep 2022 15:58:59 -0700 Subject: [PATCH 3/8] improve testing; parameters for shape, scope, dtype --- .../test_software_pipeline_async.py | 57 +++++++++++-------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 69c1757ba460..b28cca77de3e 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -24,48 +24,59 @@ from tvm.contrib.hexagon.session import Session from tvm.script import tir as T -outer = 16 -inner = 128 +outer = tvm.testing.parameter(8, 16) +inner = tvm.testing.parameter(64, 128) +scope = tvm.testing.parameter("global", "global.vtcm") +dtype = tvm.testing.parameter("uint8", "float16") -@T.prim_func -def plus_one_primfunc(A: T.Buffer[(outer, inner), "uint8"], B: T.Buffer[(outer, inner), "uint8"]): - for i in T.serial(outer): - for j in T.serial(inner): - with T.block("plus_one"): - with T.block(): - B[i, j] = A[i, j] + T.uint8(1) +@tvm.testing.fixture +def compute(outer, inner, dtype): + @T.prim_func + def plus_one_primfunc(A: T.Buffer[(outer, inner), dtype], B: T.Buffer[(outer, inner), dtype]): + for i in T.serial(outer): + for j in T.serial(inner): + with T.block("compute"): + with T.block(): + B[i, j] = A[i, j] + T.cast(1, dtype) + + def plus_one_ref(a): + return a + 1 + + return plus_one_primfunc, plus_one_ref @tvm.testing.requires_hexagon -def test_software_pipeline_with_cache_read(hexagon_launcher): - sch = tir.Schedule(plus_one_primfunc) +def test_software_pipeline_with_cache_read(hexagon_launcher, compute, outer, inner, dtype, scope): + sch = tir.Schedule(compute[0]) root = sch.get_block("root") - plus_one = sch.get_block("plus_one") - cache_read_block = sch.cache_read(plus_one, 0, "global") + compute_block = sch.get_block("compute") + cache_read_block = sch.cache_read(compute_block, 0, scope) - i, j = sch.get_loops(plus_one) + i, _ = sch.get_loops(compute_block) sch.compute_at(cache_read_block, i) sch.annotate(i, "software_pipeline_stage", [0, 1]) sch.annotate(i, "software_pipeline_order", [0, 1]) sch.annotate(i, "software_pipeline_async_stages", [0]) - tvm.lower(sch.mod["main"]).show() + a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + ref = compute[1](a_np) target_hexagon = tvm.target.hexagon("v68", link_params=True) func = tvm.build(sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)) with hexagon_launcher.start_session() as hexagon_session: - mod = hexagon_session.load_module(func) dev = hexagon_session.device - - a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype("uint8") - b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype("uint8") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) + a = tvm.nd.array(a_np, device=dev) + b = tvm.nd.array(b_np, device=dev) + mod = hexagon_session.load_module(func) mod(a, b) - ref = a_np + 1 - np.testing.assert_equal(b.numpy(), ref) + + if "int" in dtype: + np.testing.assert_equal(b.numpy(), ref) + else: + np.testing.assert_allclose(b.numpy(), ref, rtol=1e-3, atol=1e-3) if __name__ == "__main__": From 501a40cc7292a48b65e85f3050c5828a935331bd Mon Sep 17 00:00:00 2001 From: adstraw Date: Mon, 19 Sep 2022 08:48:40 -0700 Subject: [PATCH 4/8] add log statements and adjust comments to clarify pass behavior --- src/tir/transforms/lower_async_dma.cc | 31 ++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index 7e829b10f596..fa181d9c9f07 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -53,11 +53,19 @@ class AsyncDMALowerer : public StmtExprMutator { // abort if we have not seen this queue ID in `copy` transform if (queue_ids.find(queue_id) == queue_ids.end()) { + LOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the " + "`async_wait_queue_scope` transform has not been previously observed in the " + "`async_commit_queue_scope` transform"; return StmtExprMutator::VisitStmt_(op); } auto async_wait = op->body.as(); - ICHECK(async_wait && async_wait->attr_key == tir::attr::async_wait_inflight_count); + if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) { + LOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key " + "`async_wait_inflight_count`"; + return StmtExprMutator::VisitStmt_(op); + } auto call_dma_wait = Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value})); @@ -92,37 +100,50 @@ class AsyncDMALowerer : public StmtExprMutator { // walk the graph to verify this is a mem copy ... // 1) async_commit_queue_scope contains async_scope auto async_scope = op->body.as(); - ICHECK(async_scope && async_scope->attr_key == tir::attr::async_scope); + if (!async_scope || async_scope->attr_key != tir::attr::async_scope) { + LOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key " + "`async_scope`"; + return StmtExprMutator::VisitStmt_(op); + } // 2) async_scope contains single for loop auto for_loop = async_scope->body.as(); if (!for_loop) { + LOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_scope` does not contain a single `ForNode`"; return StmtExprMutator::VisitStmt_(op); } // 3) for loop contains buffer store with single index auto bufferstorenode = for_loop->body.as(); if (!bufferstorenode || bufferstorenode->indices.size() != 1) { + LOG(INFO) << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a " + "single `BufferStoreNode` with a single index variable"; return StmtExprMutator::VisitStmt_(op); } // 4) buffer store value is a buffer load with single index auto bufferloadnode = bufferstorenode->value.as(); if (!bufferloadnode || bufferloadnode->indices.size() != 1) { + LOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a " + "single `BufferLoadNode` with a single index variable"; return StmtExprMutator::VisitStmt_(op); } - // get store buffer and assert that it is contiguous + // get store buffer; assert it exists and is contiguous given it uses a single index auto bufferstore = bufferstorenode->buffer.as(); ICHECK(bufferstore && bufferstore->strides.empty()); - // get load buffer and assert that it is contiguous + // get load buffer; assert it exists and is contiguous given it uses a single index auto bufferload = bufferloadnode->buffer.as(); ICHECK(bufferload && bufferload->strides.empty()); // we will be replacing the entire for loop including its index // with a DMA copy instrinsic that spans the entire index space of the for loop - // so we will need to repace the for loop index with value zero in the buffer indices + // so we will need to replace the for loop index with value zero in the buffer indices + // thus we eliminate the index from the expression so the DMA copy receives the buffer range + // base address Map loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}}; // map loop variable to zero for the store index & simplify From 8cc7f941480cb47d36bdedba3efced412cd8f023 Mon Sep 17 00:00:00 2001 From: adstraw Date: Mon, 19 Sep 2022 17:11:37 -0700 Subject: [PATCH 5/8] generalize use_async_copy for pass enable --- src/driver/driver_api.cc | 13 ++++++++----- src/runtime/hexagon/hexagon_device_api.cc | 2 +- .../test_hexagon/test_software_pipeline_async.py | 5 ++++- .../test_tir_transform_inject_ptx_async_copy.py | 4 ++-- .../test_tir_transform_inject_software_pipeline.py | 2 +- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 54b6f59a675f..1a617dcd494d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -50,7 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); using runtime::PackedFunc; using runtime::TVMArgs; @@ -225,7 +225,11 @@ Array CreatePassList(bool disable_loop_partition) { } // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations pass_list.push_back(tir::transform::LowerVtcmAlloc()); - pass_list.push_back(tir::transform::LowerAsyncDMA()); + bool use_async_copy = pass_ctx->GetConfig("tir.use_async_copy", Bool(false)).value(); + + if (use_async_copy) { + pass_list.push_back(tir::transform::LowerAsyncDMA()); + } pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes @@ -544,10 +548,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - bool use_ptx_async_copy = - pass_ctx->GetConfig("tir.use_ptx_async_copy", Bool(false)).value(); + bool use_async_copy = pass_ctx->GetConfig("tir.use_async_copy", Bool(false)).value(); - if (use_ptx_async_copy) { + if (use_async_copy) { mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy()); } diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index cd881532a49e..57f2fc3b4b0a 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -213,7 +213,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM void* dst = args[1]; void* src = args[2]; int size = args[3]; - ICHECK(size >= 0); + ICHECK(size > 0); int ret = DMA_RETRY; do { diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index b28cca77de3e..6bcca90ec9d3 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -64,7 +64,10 @@ def test_software_pipeline_with_cache_read(hexagon_launcher, compute, outer, inn ref = compute[1](a_np) target_hexagon = tvm.target.hexagon("v68", link_params=True) - func = tvm.build(sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + func = tvm.build( + sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) + ) with hexagon_launcher.start_session() as hexagon_session: dev = hexagon_session.device diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 1a906b2fb66e..7062d5129713 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -138,7 +138,7 @@ def test_inject_async_copy(): if not tvm.testing.is_ampere_or_newer(): continue - with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}): + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda") A_np = np.random.rand(32, 128).astype(dtype) @@ -166,7 +166,7 @@ def test_inject_async_copy_shared_dyn(): if not tvm.testing.is_ampere_or_newer(): return - with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}): + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda") A_np = np.random.rand(32, 128).astype("float16") diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index edaeb7c9b639..49255e0f2094 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1390,7 +1390,7 @@ def index_map(i, j): def build_and_run(sch): if tvm.testing.is_ampere_or_newer(): - with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}): + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): f = tvm.build(sch.mod["main"], target="cuda") dev = tvm.device("cuda", 0) From d807c3691111d5295bd95ce68708659cec997569 Mon Sep 17 00:00:00 2001 From: adstraw Date: Tue, 20 Sep 2022 08:25:02 -0700 Subject: [PATCH 6/8] use DLOG instead of LOG --- src/tir/transforms/lower_async_dma.cc | 31 ++++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index fa181d9c9f07..78d363f67c02 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -53,17 +53,17 @@ class AsyncDMALowerer : public StmtExprMutator { // abort if we have not seen this queue ID in `copy` transform if (queue_ids.find(queue_id) == queue_ids.end()) { - LOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the " - "`async_wait_queue_scope` transform has not been previously observed in the " - "`async_commit_queue_scope` transform"; + DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the " + "`async_wait_queue_scope` transform has not been previously observed in the " + "`async_commit_queue_scope` transform"; return StmtExprMutator::VisitStmt_(op); } auto async_wait = op->body.as(); if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) { - LOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " - "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key " - "`async_wait_inflight_count`"; + DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key " + "`async_wait_inflight_count`"; return StmtExprMutator::VisitStmt_(op); } @@ -101,33 +101,34 @@ class AsyncDMALowerer : public StmtExprMutator { // 1) async_commit_queue_scope contains async_scope auto async_scope = op->body.as(); if (!async_scope || async_scope->attr_key != tir::attr::async_scope) { - LOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " - "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key " - "`async_scope`"; + DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key " + "`async_scope`"; return StmtExprMutator::VisitStmt_(op); } // 2) async_scope contains single for loop auto for_loop = async_scope->body.as(); if (!for_loop) { - LOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " - "`async_scope` does not contain a single `ForNode`"; + DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_scope` does not contain a single `ForNode`"; return StmtExprMutator::VisitStmt_(op); } // 3) for loop contains buffer store with single index auto bufferstorenode = for_loop->body.as(); if (!bufferstorenode || bufferstorenode->indices.size() != 1) { - LOG(INFO) << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a " - "single `BufferStoreNode` with a single index variable"; + DLOG(INFO) + << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a " + "single `BufferStoreNode` with a single index variable"; return StmtExprMutator::VisitStmt_(op); } // 4) buffer store value is a buffer load with single index auto bufferloadnode = bufferstorenode->value.as(); if (!bufferloadnode || bufferloadnode->indices.size() != 1) { - LOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a " - "single `BufferLoadNode` with a single index variable"; + DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a " + "single `BufferLoadNode` with a single index variable"; return StmtExprMutator::VisitStmt_(op); } From 18420f5df135c8c5b2745cec2661909d2b2f668b Mon Sep 17 00:00:00 2001 From: adstraw Date: Tue, 20 Sep 2022 09:19:08 -0700 Subject: [PATCH 7/8] trigger ci From 24d01fa4dd7de9006d736ac220a52c6d04f75fb7 Mon Sep 17 00:00:00 2001 From: adstraw Date: Tue, 20 Sep 2022 09:27:55 -0700 Subject: [PATCH 8/8] trigger ci again