Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DisallowAsyncStridedMemCopy post processor to remove schedules with async strided memory copies #13720

Merged
merged 2 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowDynamicLoop();
/*!
* \brief Create a postprocessor that checks if all async mem copies are not strided.
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .disallow_dynamic_loop import DisallowDynamicLoop
from .disallow_async_strided_mem_copy import DisallowAsyncStridedMemCopy
from .postproc import Postproc, PyPostproc
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
from .rewrite_layout import RewriteLayout
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that checks if the IRModule has any strided memory copies"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
class DisallowAsyncStridedMemCopy(Postproc):
    """A postprocessor that disallows schedules that use async strided mem copies.

    Parameters
    ----------
    merge_async_commit_queue_scope : bool
        Whether or not to merge the async commit queue scope.
    """

    def __init__(self, merge_async_commit_queue_scope: bool = True) -> None:
        # Construct the underlying C++ PostprocNode through the FFI; the flag is
        # passed straight to the C++ constructor registered under this name.
        self.__init_handle_by_constructor__(
            _ffi_api.PostprocDisallowAsyncStridedMemCopy,  # type: ignore # pylint: disable=no-member
            merge_async_commit_queue_scope,
        )
189 changes: 189 additions & 0 deletions src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief Check if an IRModule has any async strided mem copies. */
struct AsyncStridedMemCopyFinder : private StmtExprVisitor {
 public:
  /*!
   * \brief Scan every PrimFunc in the module for an async strided mem copy.
   * \param mod The module to scan.
   * \return true if any async copy with a non-contiguous index was found.
   */
  static bool Find(const IRModule& mod) {
    AsyncStridedMemCopyFinder finder;
    for (const auto& kv : mod->functions) {
      if (const auto* prim_func = kv.second.as<PrimFuncNode>()) {
        finder(prim_func->body);
        if (finder.found_) {
          return true;
        }
      }
    }
    return false;
  }

 private:
  void VisitStmt_(const ForNode* loop) final {
    if (!found_) {
      // Record the loop domain so DetectIterMap can analyze indices below.
      input_iters.Set(loop->loop_var, Range(loop->min, loop->extent));
      StmtExprVisitor::VisitStmt_(loop);
    }
  }

  void VisitStmt_(const AttrStmtNode* attrStmt) final {
    if (found_) {
      return;
    }
    if (attrStmt->attr_key != tir::attr::async_commit_queue_scope) {
      StmtExprVisitor::VisitStmt_(attrStmt);
      return;
    }
    // Expected shape of an async mem copy:
    //   commit_queue_scope -> async_scope attr -> For -> BufferStore(BufferLoad).
    // If any link of the chain is missing, this subtree is not an async copy;
    // keep traversing its children and return.
    // BUGFIX: each mismatch branch previously fell through without returning
    // and then dereferenced the null pointer it had just tested.
    auto async_scope = attrStmt->body.as<AttrStmtNode>();
    if (!async_scope) {
      StmtExprVisitor::VisitStmt_(attrStmt);
      return;
    }

    auto for_loop = async_scope->body.as<ForNode>();
    if (!for_loop) {
      StmtExprVisitor::VisitStmt_(attrStmt);
      return;
    }

    input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));

    auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
    if (!bufferstorenode) {
      StmtExprVisitor::VisitStmt_(attrStmt);
      return;
    }

    auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
    if (!bufferloadnode) {
      StmtExprVisitor::VisitStmt_(attrStmt);
      return;
    }

    // Both the store and load buffers must be present to analyze the copy.
    auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
    auto bufferload = bufferloadnode->buffer.as<BufferNode>();
    if (!bufferstore || !bufferload) {
      StmtExprVisitor::VisitStmt_(attrStmt);
      return;
    }

    // Use DetectIterMap to detect whether the store index is non-contiguous:
    // any iter-map error marks the copy as strided.
    arith::Analyzer analyzer;
    Array<PrimExpr> store_index = bufferstorenode->indices;
    auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
                                        arith::IterMapLevel::Surjective, &analyzer, false);
    if (!store_iter_map->errors.empty()) {
      found_ = true;
    }

    // Same check for the load index.
    Array<PrimExpr> load_index = bufferloadnode->indices;
    auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
                                       arith::IterMapLevel::Surjective, &analyzer, false);
    if (!load_iter_map->errors.empty()) {
      found_ = true;
    }

    if (!found_) {
      StmtExprVisitor::VisitStmt_(attrStmt);
    }
  }

  bool found_ = false;
  Map<Var, Range> input_iters = Map<Var, Range>();
};

} // namespace tir

namespace meta_schedule {

/*! \brief Reject schedules whose lowered form contains async strided mem copies. */
class DisallowAsyncStridedMemCopyNode : public PostprocNode {
 public:
  // Inherited from PostprocNode; no tuning-context state is needed.
  void InitializeWithTuneContext(const TuneContext& context) final {}
  // Inherited from PostprocNode.
  // Lowers each PrimFunc through the pipeline prefix that materializes async
  // commit queue scopes (InjectSoftwarePipeline), then rejects the schedule
  // if any async copy uses a strided (non-contiguous) index.
  bool Apply(const tir::Schedule& sch) final {
    IRModule mod = sch->mod();
    for (const auto& kv : mod->functions) {
      const GlobalVar& g_var = kv.first;
      const BaseFunc& base_func = kv.second;
      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
        IRModule lowered{nullptr};
        try {
          auto pass_list = Array<tvm::transform::Pass>();
          pass_list.push_back(tir::transform::LowerInitBlock());
          pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
          pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
          pass_list.push_back(tir::transform::CompactBufferAllocation());
          pass_list.push_back(tir::transform::LowerMatchBuffer());
          pass_list.push_back(tir::transform::InjectSoftwarePipeline());
          pass_list.push_back(tir::transform::LowerOpaqueBlock());
          pass_list.push_back(tir::transform::FlattenBuffer());
          pass_list.push_back(tir::transform::BF16Legalize());
          pass_list.push_back(tir::transform::NarrowDataType(32));
          pass_list.push_back(tir::transform::Simplify());
          pass_list.push_back(tir::transform::InjectVirtualThread());
          pass_list.push_back(tir::transform::InjectDoubleBuffer());
          pass_list.push_back(tir::transform::VectorizeLoop(true));
          pass_list.push_back(tir::transform::StorageRewrite());
          transform::PassContext pass_ctx = transform::PassContext::Current();
          // NOTE(review): this mutates the *current* pass context's config and
          // never restores it — confirm the leak into later passes is intended.
          pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
                               Bool(merge_async_commit_queue_scope));
          tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
                                     runtime::String(g_var->name_hint));
          // Renamed from `mod` to avoid shadowing the schedule's module above.
          IRModule func_mod =
              IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
          lowered = tvm::transform::Sequential(pass_list)(std::move(func_mod));
        } catch (const dmlc::Error& e) {
          // A function that fails to lower at all cannot be a valid candidate.
          return false;
        }
        if (tir::AsyncStridedMemCopyFinder::Find(lowered)) {
          return false;
        }
      }
    }
    return true;
  }
  // Inherited from PostprocNode
  Postproc Clone() const {
    ObjectPtr<DisallowAsyncStridedMemCopyNode> n =
        make_object<DisallowAsyncStridedMemCopyNode>(*this);
    return Postproc(n);
  }

  /*! \brief Whether to merge async commit queue scopes when lowering. */
  bool merge_async_commit_queue_scope = true;

  static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
  TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
};

/*! \brief Factory for the DisallowAsyncStridedMemCopy postprocessor. */
Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
  auto node = make_object<DisallowAsyncStridedMemCopyNode>();
  node->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
  return Postproc(std::move(node));
}

TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy")
    .set_body_typed(Postproc::DisallowAsyncStridedMemCopy);

} // namespace meta_schedule
} // namespace tvm
11 changes: 7 additions & 4 deletions src/tir/transforms/lower_async_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ class AsyncDMALowerer : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}

// Add the current loop to the input iters mapping.
input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));

// 3) for loop contains buffer store with single index
auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
Expand Down Expand Up @@ -156,8 +159,8 @@ class AsyncDMALowerer : public StmtExprMutator {

// Use DetectIterMap to detect whether store index is non-contiguous.
arith::Analyzer analyzer;
auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
&analyzer, false);
auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
arith::IterMapLevel::Surjective, &analyzer, false);
if (!store_iter_map->errors.empty()) {
LOG(FATAL)
<< "Unable to lower async dma for non contiguous memory access with store index: "
Expand All @@ -173,8 +176,8 @@ class AsyncDMALowerer : public StmtExprMutator {
Array<PrimExpr> load_index = bufferloadnode->indices;

// Use DetectIterMap to detect whether load index is non-contiguous.
auto load_iter_map =
DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
arith::IterMapLevel::Surjective, &analyzer, false);
if (!load_iter_map->errors.empty()) {
LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
<< load_index;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

import tvm
from tvm import meta_schedule as ms
from tvm import tir
from tvm.script import tir as T
from tvm.target import Target


def _target() -> Target:
    """Target used throughout these tests: Hexagon device with an LLVM host."""
    target = Target("hexagon", host="llvm")
    return target


def _create_context(mod, target) -> ms.TuneContext:
    """Build a TuneContext whose only postprocessor is DisallowAsyncStridedMemCopy."""
    space_generator = ms.space_generator.PostOrderApply(
        sch_rules=[],
        postprocs=[ms.postproc.DisallowAsyncStridedMemCopy()],
        mutator_probs={},
    )
    return ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=space_generator,
        task_name="test",
    )


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off

@tvm.script.ir_module
class Matmul:
    # Workload for the tests: a plain 1024x1024x1024 fp32 matmul, C = A @ B.
    # Comments only — TVMScript bodies are parsed, so the code is untouched.
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        # i, j are spatial axes; k is the reduction axis ("SSR" remap below).
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument


def test_postproc_disallow_async_strided_mem_copy_allows():
    """An async copy of buffer A (read index 0) should pass the postprocessor."""
    mod = Matmul
    sch = tir.Schedule(mod, debug_mask="all")

    matmul_block = sch.get_block("matmul")
    loops = sch.get_loops(matmul_block)
    # Cache-read buffer A into VTCM; this copy is contiguous.
    cache_read = sch.cache_read(matmul_block, 0, "global.vtcm")
    sch.compute_at(cache_read, loops[1])

    # Make the copy stage asynchronous via software pipelining.
    sch.annotate(loops[1], "software_pipeline_stage", [0, 1])
    sch.annotate(loops[1], "software_pipeline_order", [0, 1])
    sch.annotate(loops[1], "software_pipeline_async_stages", [0])

    # Removed leftover debug `sch.mod.show()` that dumped the module per run.
    ctx = _create_context(sch.mod, target=_target())
    assert ctx.space_generator.postprocs[0].apply(sch)


def test_postproc_disallow_async_strided_mem_copy_disallows():
    """An async copy of buffer B (read index 1) is strided and must be rejected."""
    mod = Matmul
    sch = tir.Schedule(mod, debug_mask="all")

    matmul_block = sch.get_block("matmul")
    loops = sch.get_loops(matmul_block)
    # Cache-read buffer B: its access pattern makes the copy strided.
    cache_read = sch.cache_read(matmul_block, 1, "global.vtcm")
    sch.compute_at(cache_read, loops[1])

    # Make the copy stage asynchronous via software pipelining.
    sch.annotate(loops[1], "software_pipeline_stage", [0, 1])
    sch.annotate(loops[1], "software_pipeline_order", [0, 1])
    sch.annotate(loops[1], "software_pipeline_async_stages", [0])

    # Removed leftover debug `sch.mod.show()` that dumped the module per run.
    ctx = _create_context(sch.mod, target=_target())
    assert not ctx.space_generator.postprocs[0].apply(sch)


if __name__ == "__main__":
    # Run both cases in order when invoked directly.
    for _case in (
        test_postproc_disallow_async_strided_mem_copy_allows,
        test_postproc_disallow_async_strided_mem_copy_disallows,
    ):
        _case()