Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon Use…
Browse files Browse the repository at this point in the history
…r DMA (apache#12785)

* [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to HexagonUserDMA

* save queue ID in `copy`, inspect in `wait` transform; add comments

* improve testing; parameters for shape, scope, dtype

* add log statements and adjust comments to clarify pass behavior

* generalize use_async_copy for pass enable

* use DLOG instead of LOG

* trigger ci

* trigger ci again
  • Loading branch information
adstraw authored and xinetzone committed Nov 25, 2022
1 parent 323b1f8 commit b0d5de4
Show file tree
Hide file tree
Showing 10 changed files with 367 additions and 7 deletions.
10 changes: 10 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ TVM_DLL const Op& texture2d_load();
*/
TVM_DLL const Op& mem_copy();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
*
* As emitted by the LowerAsyncDMA pass, the call takes four arguments, in order:
*   (queue_id, dst_address, src_address, size)
* where the addresses are `address_of` handles to the first element copied and
* `size` is the loop extent multiplied by the element size in bytes.
* The call is emitted with an int32 result type.
*/
TVM_DLL const Op& dma_copy();

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
*
* As emitted by the LowerAsyncDMA pass, the call takes two arguments, in order:
*   (queue_id, inflight_count)
* where `inflight_count` is the value of the `async_wait_inflight_count` annotation.
* The call is emitted with an int32 result type.
*/
TVM_DLL const Op& dma_wait();

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ TVM_DLL Pass TextureFlatten();
*/
TVM_DLL Pass LowerVtcmAlloc();

/*!
* \brief Lower Async TIR primitives to DMA copy and wait builtins
*
* Rewrites `async_commit_queue_scope` / `async_scope` annotated copy loops into
* `dma_copy` calls and `async_wait_queue_scope` / `async_wait_inflight_count`
* annotations into `dma_wait` calls.
*
* \return The pass.
*/
TVM_DLL Pass LowerAsyncDMA();

/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
Expand Down
12 changes: 8 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
// Generic async-copy switch. When true, this file adds tir::transform::LowerAsyncDMA()
// to the lowering pass list and tir::transform::InjectPTXAsyncCopy() to the mixed-module
// pass list. Both reads default to false when the option is unset.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);

using runtime::PackedFunc;
using runtime::TVMArgs;
Expand Down Expand Up @@ -225,6 +225,11 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

if (use_async_copy) {
pass_list.push_back(tir::transform::LowerAsyncDMA());
}
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
Expand Down Expand Up @@ -543,10 +548,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

bool use_ptx_async_copy =
pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

if (use_ptx_async_copy) {
if (use_async_copy) {
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
}

Expand Down
25 changes: 25 additions & 0 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "../workspace_pool.h"
#include "hexagon_common.h"
#include "hexagon_user_dma.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -206,6 +207,30 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

// Packed function "device_api.hexagon.dma_copy".
// Arguments: [0] queue id, [1] destination pointer, [2] source pointer, [3] copy size.
// Returns (as int32) the final status reported by HexagonUserDMA::Copy.
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
  // Only queue 0 is accepted.
  int queue_id = args[0];
  ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");

  void* dst = args[1];
  void* src = args[2];

  // A copy must move at least one unit of `size`.
  int size = args[3];
  ICHECK(size > 0);

  // Keep re-issuing the copy for as long as the DMA engine answers DMA_RETRY.
  int status = DMA_RETRY;
  while (status == DMA_RETRY) {
    status = HexagonUserDMA::Get().Copy(dst, src, size);
  }

  *rv = static_cast<int32_t>(status);
});

// Packed function "device_api.hexagon.dma_wait".
// Arguments: [0] queue id, [1] in-flight count forwarded to HexagonUserDMA::Wait.
// Always returns 0 (as int32).
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
  // Only queue 0 is accepted.
  int queue_id = args[0];
  ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");

  // A negative in-flight count is rejected.
  int inflight = args[1];
  ICHECK(inflight >= 0);

  HexagonUserDMA::Get().Wait(inflight);

  const int32_t status = 0;
  *rv = status;
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
// mem_copy: registered with an opaque call effect.
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// dma_copy: registered with an opaque call effect, like mem_copy.
// NOTE(review): presumably kOpaque keeps optimizations from reordering or removing the
// call - confirm against the CallEffectKind documentation.
TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// dma_wait: registered with an opaque call effect, like dma_copy.
TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// assume: registered with the kEmbedInfo call effect and exactly one input.
TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);
Expand Down
194 changes: 194 additions & 0 deletions src/tir/transforms/lower_async_dma.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file lower_async_dma.cc
*/

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

class AsyncDMALowerer : public StmtExprMutator {
public:
AsyncDMALowerer() {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
// Convert this, for example:
// attr [0] "async_wait_queue_scope" = 0;
// attr [0] "async_wait_inflight_count" = 0;
//
// To this:
// @tir.dma_wait(
// 0, /* queue id */
// 0, /* in flight count */
// dtype=int32
// )
if (op->attr_key == tir::attr::async_wait_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// abort if we have not seen this queue ID in `copy` transform
if (queue_ids.find(queue_id) == queue_ids.end()) {
DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
"`async_wait_queue_scope` transform has not been previously observed in the "
"`async_commit_queue_scope` transform";
return StmtExprMutator::VisitStmt_(op);
}

auto async_wait = op->body.as<AttrStmtNode>();
if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
"`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
"`async_wait_inflight_count`";
return StmtExprMutator::VisitStmt_(op);
}

auto call_dma_wait =
Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));

// concatenate the call with the body and return
return SeqStmt({call_dma_wait, async_wait->body});

// Convert this, for example:
// attr [0] "async_commit_queue_scope" = 0;
// attr [0] "async_scope" = 1;
// for (ax0: int32, 0, 128) {
// A_global[ax0] = A[ax0]
// }
//
// To this:
// @tir.dma_copy(
// 0, /* queue id */
// @tir.address_of(A_global[0], dtype=handle),
// @tir.address_of(A[0], dtype=handle),
// 128, /* size */
// dtype=int32
// )
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// save queue ID for inspection in `wait` transform
queue_ids.insert(queue_id);

// walk the graph to verify this is a mem copy ...
// 1) async_commit_queue_scope contains async_scope
auto async_scope = op->body.as<AttrStmtNode>();
if (!async_scope || async_scope->attr_key != tir::attr::async_scope) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
"`async_commit_queue_scope` does not contain an `AttrStmtNode` with key "
"`async_scope`";
return StmtExprMutator::VisitStmt_(op);
}

// 2) async_scope contains single for loop
auto for_loop = async_scope->body.as<ForNode>();
if (!for_loop) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
"`async_scope` does not contain a single `ForNode`";
return StmtExprMutator::VisitStmt_(op);
}

// 3) for loop contains buffer store with single index
auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
DLOG(INFO)
<< "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a "
"single `BufferStoreNode` with a single index variable";
return StmtExprMutator::VisitStmt_(op);
}

// 4) buffer store value is a buffer load with single index
auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
if (!bufferloadnode || bufferloadnode->indices.size() != 1) {
DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a "
"single `BufferLoadNode` with a single index variable";
return StmtExprMutator::VisitStmt_(op);
}

// get store buffer; assert it exists and is contiguous given it uses a single index
auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
ICHECK(bufferstore && bufferstore->strides.empty());

// get load buffer; assert it exists and is contiguous given it uses a single index
auto bufferload = bufferloadnode->buffer.as<BufferNode>();
ICHECK(bufferload && bufferload->strides.empty());

// we will be replacing the entire for loop including its index
// with a DMA copy instrinsic that spans the entire index space of the for loop
// so we will need to replace the for loop index with value zero in the buffer indices
// thus we eliminate the index from the expression so the DMA copy receives the buffer range
// base address
Map<Var, PrimExpr> loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}};

// map loop variable to zero for the store index & simplify
Array<PrimExpr> store_index = bufferstorenode->indices;
store_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

// map loop variable to zero for the load index & simplify
Array<PrimExpr> load_index = bufferloadnode->indices;
load_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
{queue_id,
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferstorenode->buffer, store_index)}),
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferloadnode->buffer, load_index)}),
for_loop->extent * bufferloadnode->dtype.bytes()}));
}
return StmtExprMutator::VisitStmt_(op);
}

private:
std::set<int> queue_ids;
};

namespace transform {

/*!
 * \brief Create the "tir.LowerAsyncDMA" PrimFunc pass.
 *
 * The pass runs AsyncDMALowerer over the body of each PrimFunc it is given and
 * stores the rewritten body back into the function.
 *
 * \return The pass, registered at opt level 0 with no required passes.
 */
Pass LowerAsyncDMA() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto func_node = f.CopyOnWrite();
    AsyncDMALowerer lowerer;
    func_node->body = lowerer(std::move(func_node->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA);
} // namespace transform

} // namespace tir
} // namespace tvm
30 changes: 30 additions & 0 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ class BuiltinLower : public StmtExprMutator {
return make_zero(op->dtype);
} else if (op->op.same_as(builtin::mem_copy())) {
return MakeMemCopy(op);
} else if (op->op.same_as(builtin::dma_copy())) {
return MakeDMACopy(op);
} else if (op->op.same_as(builtin::dma_wait())) {
return MakeDMAWait(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand All @@ -335,6 +339,32 @@ class BuiltinLower : public StmtExprMutator {
return VisitExpr(call_packed);
}

/*!
 * \brief Lower a `dma_copy` builtin call to a packed call into the device API.
 * \param op The `dma_copy` call; its arguments are (queue_id, dst, src, size).
 * \return The visited `tvm_call_packed` expression targeting
 *         "device_api.<device name>.dma_copy" with the same four arguments.
 */
PrimExpr MakeDMACopy(const CallNode* op) {
  PrimExpr queue_id = op->args[0];
  PrimExpr dst = op->args[1];
  PrimExpr src = op->args[2];
  PrimExpr size = op->args[3];

  // The packed function name is derived from the device type of the function being lowered.
  std::string packed_name = "device_api.";
  packed_name += std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));
  packed_name += ".dma_copy";

  // Visit the new call so that the packed-call lowering of this mutator applies to it too.
  return VisitExpr(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(packed_name), queue_id, dst, src, size}));
}

/*!
 * \brief Lower a `dma_wait` builtin call to a packed call into the device API.
 * \param op The `dma_wait` call; its arguments are (queue_id, inflight).
 * \return The visited `tvm_call_packed` expression targeting
 *         "device_api.<device name>.dma_wait" with the same two arguments.
 */
PrimExpr MakeDMAWait(const CallNode* op) {
  PrimExpr queue_id = op->args[0];
  PrimExpr inflight = op->args[1];

  // The packed function name is derived from the device type of the function being lowered.
  std::string packed_name = "device_api.";
  packed_name += std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));
  packed_name += ".dma_wait";

  // Visit the new call so that the packed-call lowering of this mutator applies to it too.
  return VisitExpr(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(packed_name), queue_id, inflight}));
}

// call shape
PrimExpr MakeShape(const CallNode* op) {
// if args.size() == 0, it represents a scalar shape ()
Expand Down
Loading

0 comments on commit b0d5de4

Please sign in to comment.