Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA #12785

Merged
merged 8 commits into from
Sep 20, 2022
Merged
10 changes: 10 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ TVM_DLL const Op& texture2d_load();
*/
TVM_DLL const Op& mem_copy();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
*/
TVM_DLL const Op& dma_copy();

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
*/
TVM_DLL const Op& dma_wait();
adstraw marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ TVM_DLL Pass TextureFlatten();
*/
TVM_DLL Pass LowerVtcmAlloc();

/*!
* \brief Lower Async TIR primitives to DMA copy and wait builtins
*/
TVM_DLL Pass LowerAsyncDMA();

/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::LowerAsyncDMA());
adstraw marked this conversation as resolved.
Show resolved Hide resolved
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
Expand Down
25 changes: 25 additions & 0 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "../workspace_pool.h"
#include "hexagon_common.h"
#include "hexagon_user_dma.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -206,6 +207,30 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");
void* dst = args[1];
void* src = args[2];
int size = args[3];
ICHECK(size >= 0);

int ret = DMA_RETRY;
do {
ret = HexagonUserDMA::Get().Copy(dst, src, size);
} while (ret == DMA_RETRY);
*rv = static_cast<int32_t>(ret);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");
int inflight = args[1];
ICHECK(inflight >= 0);
HexagonUserDMA::Get().Wait(inflight);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);
Expand Down
172 changes: 172 additions & 0 deletions src/tir/transforms/lower_async_dma.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file lower_async_dma.cc
*/

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

class AsyncDMALowerer : public StmtExprMutator {
public:
AsyncDMALowerer() {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
adstraw marked this conversation as resolved.
Show resolved Hide resolved
// Convert this, for example:
// attr [0] "async_wait_queue_scope" = 0;
// attr [0] "async_wait_inflight_count" = 0;
//
// To this:
// @tir.dma_wait(
// 0, /* queue id */
// 0, /* in flight count */
// dtype=int32
// )
if (op->attr_key == tir::attr::async_wait_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// abort if we have not seen this queue ID in `copy` transform
if (queue_ids.find(queue_id) == queue_ids.end()) {
adstraw marked this conversation as resolved.
Show resolved Hide resolved
return StmtExprMutator::VisitStmt_(op);
}

auto async_wait = op->body.as<AttrStmtNode>();
ICHECK(async_wait && async_wait->attr_key == tir::attr::async_wait_inflight_count);

auto call_dma_wait =
Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));

// concatenate the call with the body and return
return SeqStmt({call_dma_wait, async_wait->body});

// Convert this, for example:
// attr [0] "async_commit_queue_scope" = 0;
// attr [0] "async_scope" = 1;
// for (ax0: int32, 0, 128) {
// A_global[ax0] = A[ax0]
// }
//
// To this:
// @tir.dma_copy(
// 0, /* queue id */
// @tir.address_of(A_global[0], dtype=handle),
// @tir.address_of(A[0], dtype=handle),
// 128, /* size */
// dtype=int32
// )
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// save queue ID for inspection in `wait` transform
queue_ids.insert(queue_id);

// walk the graph to verify this is a mem copy ...
// 1) async_commit_queue_scope contains async_scope
auto async_scope = op->body.as<AttrStmtNode>();
ICHECK(async_scope && async_scope->attr_key == tir::attr::async_scope);

// 2) async_scope contains single for loop
auto for_loop = async_scope->body.as<ForNode>();
if (!for_loop) {
return StmtExprMutator::VisitStmt_(op);
}

// 3) for loop contains buffer store with single index
auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
return StmtExprMutator::VisitStmt_(op);
}

// 4) buffer store value is a buffer load with single index
auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
if (!bufferloadnode || bufferloadnode->indices.size() != 1) {
adstraw marked this conversation as resolved.
Show resolved Hide resolved
return StmtExprMutator::VisitStmt_(op);
}

// get store buffer and assert that it is contiguous
auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
ICHECK(bufferstore && bufferstore->strides.empty());

// get load buffer and assert that it is contiguous
auto bufferload = bufferloadnode->buffer.as<BufferNode>();
ICHECK(bufferload && bufferload->strides.empty());
adstraw marked this conversation as resolved.
Show resolved Hide resolved

// we will be replacing the entire for loop including its index
// with a DMA copy instrinsic that spans the entire index space of the for loop
// so we will need to repace the for loop index with value zero in the buffer indices
adstraw marked this conversation as resolved.
Show resolved Hide resolved
Map<Var, PrimExpr> loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}};

// map loop variable to zero for the store index & simplify
Array<PrimExpr> store_index = bufferstorenode->indices;
store_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

// map loop variable to zero for the load index & simplify
Array<PrimExpr> load_index = bufferloadnode->indices;
load_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
{queue_id,
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferstorenode->buffer, store_index)}),
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferloadnode->buffer, load_index)}),
for_loop->extent * bufferloadnode->dtype.bytes()}));
}
return StmtExprMutator::VisitStmt_(op);
}

private:
std::set<int> queue_ids;
};

namespace transform {

Pass LowerAsyncDMA() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto fptr = f.CopyOnWrite();
fptr->body = AsyncDMALowerer()(std::move(fptr->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA);
} // namespace transform

} // namespace tir
} // namespace tvm
30 changes: 30 additions & 0 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ class BuiltinLower : public StmtExprMutator {
return make_zero(op->dtype);
} else if (op->op.same_as(builtin::mem_copy())) {
return MakeMemCopy(op);
} else if (op->op.same_as(builtin::dma_copy())) {
return MakeDMACopy(op);
} else if (op->op.same_as(builtin::dma_wait())) {
return MakeDMAWait(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand All @@ -335,6 +339,32 @@ class BuiltinLower : public StmtExprMutator {
return VisitExpr(call_packed);
}

PrimExpr MakeDMACopy(const CallNode* op) {
PrimExpr queue_id = op->args[0];
PrimExpr dst = op->args[1];
PrimExpr src = op->args[2];
PrimExpr size = op->args[3];

std::string fdevapi_prefix =
"device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));

Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".dma_copy"), queue_id, dst, src, size});
return VisitExpr(call_packed);
}

PrimExpr MakeDMAWait(const CallNode* op) {
PrimExpr queue_id = op->args[0];
PrimExpr inflight = op->args[1];

std::string fdevapi_prefix =
"device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));

Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".dma_wait"), queue_id, inflight});
return VisitExpr(call_packed);
}

// call shape
PrimExpr MakeShape(const CallNode* op) {
// if args.size() == 0, it represents a scalar shape ()
Expand Down
72 changes: 72 additions & 0 deletions tests/python/contrib/test_hexagon/test_software_pipeline_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
import pytest
import numpy as np

import tvm
from tvm import tir
from tvm.contrib.hexagon.session import Session
from tvm.script import tir as T

outer = 16
inner = 128


@T.prim_func
def plus_one_primfunc(A: T.Buffer[(outer, inner), "uint8"], B: T.Buffer[(outer, inner), "uint8"]):
for i in T.serial(outer):
for j in T.serial(inner):
with T.block("plus_one"):
with T.block():
B[i, j] = A[i, j] + T.uint8(1)


@tvm.testing.requires_hexagon
def test_software_pipeline_with_cache_read(hexagon_launcher):
sch = tir.Schedule(plus_one_primfunc)
root = sch.get_block("root")
plus_one = sch.get_block("plus_one")
cache_read_block = sch.cache_read(plus_one, 0, "global")

i, j = sch.get_loops(plus_one)
sch.compute_at(cache_read_block, i)
sch.annotate(i, "software_pipeline_stage", [0, 1])
sch.annotate(i, "software_pipeline_order", [0, 1])
sch.annotate(i, "software_pipeline_async_stages", [0])

tvm.lower(sch.mod["main"]).show()

target_hexagon = tvm.target.hexagon("v68", link_params=True)
func = tvm.build(sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon))

with hexagon_launcher.start_session() as hexagon_session:
mod = hexagon_session.load_module(func)
dev = hexagon_session.device

a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype("uint8")
b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype("uint8")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
mod(a, b)
ref = a_np + 1
np.testing.assert_equal(b.numpy(), ref)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))