[TIR] Output DeclBuffer in SplitHostDevice (#15493)
* [TIR] Output DeclBuffer in SplitHostDevice

If the generated device function uses a buffer, generate a DeclBuffer
for the buffer at the top of the device function.

This is a subset of the changes made in
#14778, broken out for ease of
testing and review.

* Updated thread sync test to account for DeclBuffer

* Updated LowerWarp unit tests to find Allocate in PrimFunc
Lunderberg authored Aug 28, 2023
1 parent 8f60213 commit c921781
Showing 3 changed files with 14 additions and 7 deletions.
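
To make the first change concrete: after this commit, a device function split
out by SplitHostDevice re-declares each buffer it uses at the top of its body.
A rough TVMScript sketch of that shape (the kernel name, buffer, dtype, and
extent here are hypothetical, not taken from this commit):

    # Hypothetical sketch of a post-SplitHostDevice device function.
    from tvm.script import tir as T

    @T.prim_func
    def default_function_kernel(A_data: T.handle):
        T.func_attr({"target": T.target("cuda"), "tir.noalias": True,
                     "tir.is_global_func": True})
        # New in this commit: a DeclBuffer binds the buffer to its backing
        # pointer instead of leaving the mapping implicit.
        A = T.decl_buffer((16,), "float32", data=A_data)
        for i in range(16):
            A[i] = T.float32(0)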
9 changes: 6 additions & 3 deletions src/tir/transforms/split_host_device.cc
@@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator {
 
  private:
   Stmt SplitDeviceFunc(Stmt body, Target device_target) {
-    Array<Var> params = [&]() {
+    auto [params, buffers_to_declare] = [&]() -> std::tuple<Array<Var>, Array<Buffer>> {
       VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
       use_def(body);
 
@@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator {
         };
         return sort_key(a) < sort_key(b);
       });
-      return params;
+      return {params, use_def.undefined_buffers_};
     }();
 
     // CodeGenCPU is used for some device-side targets, such as
@@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator {
       kernel_ret_type = VoidType();
     }
 
-    GlobalVar kernel_symbol_global = var_supply_();
+    for (Buffer buf : buffers_to_declare) {
+      body = DeclBuffer(buf, std::move(body));
+    }
     PrimFunc device_func(params, body, kernel_ret_type);
     device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                                      {tir::attr::kNoAlias, Bool(true)},
                                                      {tir::attr::kIsGlobalFunc, Bool(true)}});
 
+    GlobalVar kernel_symbol_global = var_supply_();
     (*device_mod_)->Add(kernel_symbol_global, device_func);
     Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
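
Downstream code that previously matched on the structure of a device
function's body will now encounter DeclBuffer wrappers first. A minimal
helper for peeling them (the function name is ours, not part of this commit;
it assumes DeclBuffer nodes are the only wrappers added here):

    from tvm import tir

    def skip_decl_buffers(stmt):
        # Walk past the DeclBuffer nodes that SplitHostDevice now emits
        # at the top of a device function body.
        while isinstance(stmt, tir.DeclBuffer):
            stmt = stmt.body
        return stmt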
10 changes: 7 additions & 3 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -18,7 +18,7 @@
 import pytest
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, tir
 from tvm.contrib.nvcc import have_fp16
 

@@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():
 
     mod = _run_passes(mod)
     fdevice = mod["f_kernel"]
-    allocate = fdevice.body.body
+
+    allocate = fdevice
+    while not isinstance(allocate, tir.Allocate):
+        allocate = allocate.body
+
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
-    assert fdevice.body.body.extents[0].value == 2
+    assert allocate.extents[0].value == 2
 
 
 @tvm.testing.requires_cuda
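
The while loop above makes the test robust to how many wrapper nodes precede
the Allocate. A sketch of the same idea as a reusable helper (the name is
ours; it assumes the wanted node is reachable by following .body links):

    from tvm import tir

    def first_instance(stmt, node_type):
        # Descend through .body links until the first node of node_type.
        while not isinstance(stmt, node_type):
            stmt = stmt.body
        return stmt

    # Mirrors the updated test:
    # allocate = first_instance(fdevice.body, tir.Allocate)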
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_transform_thread_sync.py
@@ -57,7 +57,7 @@ def test_thread_storage_sync():
     func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     mod = run_passes(func)
     f = mod["test_kernel"]
-    body_list = tvm.tir.stmt_list(f.body.body.body)
+    body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
     assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))


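The three extra .body hops account for the DeclBuffer nodes that now wrap the
kernel body. A depth-insensitive alternative is sketched below (an assumption
on our part, not part of this commit: the statement sequence is a SeqStmt
reachable by following .body links through the wrapper nodes):

    from tvm import tir

    body = f.body
    # Skip wrapper nodes (AttrStmt, DeclBuffer, ...) instead of hard-coding
    # how many levels of nesting the lowering passes produce.
    while not isinstance(body, tir.SeqStmt):
        body = body.body
    body_list = tir.stmt_list(body)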
