diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 9b1dbf1a6618..b9fc056f1962 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator {
 
  private:
   Stmt SplitDeviceFunc(Stmt body, Target device_target) {
-    Array<Var> params = [&]() {
+    auto [params, buffers_to_declare] = [&]() -> std::tuple<Array<Var>, Array<Buffer>> {
       VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
       use_def(body);
 
@@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator {
         };
         return sort_key(a) < sort_key(b);
       });
-      return params;
+      return {params, use_def.undefined_buffers_};
     }();
 
     // CodeGenCPU is used for some device-side targets, such as
@@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator {
       kernel_ret_type = VoidType();
     }
 
-    GlobalVar kernel_symbol_global = var_supply_();
+    for (Buffer buf : buffers_to_declare) {
+      body = DeclBuffer(buf, std::move(body));
+    }
     PrimFunc device_func(params, body, kernel_ret_type);
     device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                                      {tir::attr::kNoAlias, Bool(true)},
                                                      {tir::attr::kIsGlobalFunc, Bool(true)}});
 
+    GlobalVar kernel_symbol_global = var_supply_();
     (*device_mod_)->Add(kernel_symbol_global, device_func);
     Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
 
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index c7e90d4e7dc9..99ccc5556585 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -18,7 +18,7 @@
 import pytest
 
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, tir
 from tvm.contrib.nvcc import have_fp16
 
@@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():
     mod = _run_passes(mod)
 
     fdevice = mod["f_kernel"]
-    allocate = fdevice.body.body
+
+    allocate = fdevice
+    while not isinstance(allocate, tir.Allocate):
+        allocate = allocate.body
+
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
-    assert fdevice.body.body.extents[0].value == 2
+    assert allocate.extents[0].value == 2
 
 
 @tvm.testing.requires_cuda
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 571927dffe6e..2cfc65aae069 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -57,7 +57,7 @@ def test_thread_storage_sync():
     func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     mod = run_passes(func)
     f = mod["test_kernel"]
-    body_list = tvm.tir.stmt_list(f.body.body.body)
+    body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
 
     assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
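
Note on the test changes: because the split device function body is now wrapped in DeclBuffer nodes, the statement nesting inside generated kernels is deeper, so the tests can no longer hard-code a fixed chain of .body accesses. A minimal sketch of the traversal pattern the updated lower_warp_memory test relies on, factored into a hypothetical find_stmt helper (the helper name and the usage line are illustrative, not part of this diff):

    from tvm import tir

    def find_stmt(stmt, node_type):
        # Walk down through wrapper statements that expose a .body field
        # (PrimFunc, AttrStmt, DeclBuffer, LetStmt, ...) until a node of
        # the requested type is reached.
        while not isinstance(stmt, node_type):
            stmt = stmt.body
        return stmt

    # Usage, mirroring the updated test:
    #   allocate = find_stmt(mod["f_kernel"], tir.Allocate)

Matching by node type rather than by nesting depth keeps the assertions stable if later changes to SplitHostDevice add or remove wrapper statements around the Allocate.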