diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 9b1dbf1a6618..b9fc056f1962 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator {
 
  private:
   Stmt SplitDeviceFunc(Stmt body, Target device_target) {
-    Array<Var> params = [&]() {
+    auto [params, buffers_to_declare] = [&]() -> std::tuple<Array<Var>, Array<Buffer>> {
       VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
       use_def(body);
 
@@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator {
         };
         return sort_key(a) < sort_key(b);
       });
-      return params;
+      return {params, use_def.undefined_buffers_};
     }();
 
     // CodeGenCPU is used for some device-side targets, such as
@@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator {
       kernel_ret_type = VoidType();
     }
 
-    GlobalVar kernel_symbol_global = var_supply_();
+    for (Buffer buf : buffers_to_declare) {
+      body = DeclBuffer(buf, std::move(body));
+    }
     PrimFunc device_func(params, body, kernel_ret_type);
     device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                                      {tir::attr::kNoAlias, Bool(true)},
                                                      {tir::attr::kIsGlobalFunc, Bool(true)}});
 
+    GlobalVar kernel_symbol_global = var_supply_();
     (*device_mod_)->Add(kernel_symbol_global, device_func);
     Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
 
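The C++ change above does two things: the splitter's extraction lambda now returns both the kernel parameters and the buffers the kernel body uses without defining (via `use_def.undefined_buffers_`), and the body is re-wrapped in `DeclBuffer` nodes so those buffer definitions survive the host/device split; the `GlobalVar` for the kernel symbol is simply requested later, once the device `PrimFunc` has been built. To see the effect from Python, one can gather the `DeclBuffer` nodes of a split kernel with the public `tir.stmt_functor.post_order_visit` visitor. `collect_decl_buffers` below is a hypothetical helper for illustration, not part of TVM:

    from tvm import tir

    def collect_decl_buffers(func):
        # Hypothetical helper: gather every buffer declared inside a
        # PrimFunc body. After this patch, a split device kernel should
        # carry a DeclBuffer for each buffer it uses but does not define.
        decls = []

        def visit(node):
            if isinstance(node, tir.DeclBuffer):
                decls.append(node.buffer)

        tir.stmt_functor.post_order_visit(func.body, visit)
        return decls
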
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index c7e90d4e7dc9..99ccc5556585 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -18,7 +18,7 @@
 import pytest
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, tir
 from tvm.contrib.nvcc import have_fp16
 
 
@@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():
 
     mod = _run_passes(mod)
     fdevice = mod["f_kernel"]
-    allocate = fdevice.body.body
+
+    allocate = fdevice
+    while not isinstance(allocate, tir.Allocate):
+        allocate = allocate.body
+
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
-    assert fdevice.body.body.extents[0].value == 2
+    assert allocate.extents[0].value == 2
 
 
 @tvm.testing.requires_cuda
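The updated test no longer hard-codes where the `Allocate` sits (`fdevice.body.body`): since the splitter may now wrap the kernel body in `DeclBuffer` nodes, it walks down `.body` until it reaches the `Allocate`. The same pattern generalizes to any single-child wrapper chain; `first_node_of_type` below is a hypothetical helper mirroring the test's while loop, assuming every intermediate node exposes a `.body` field (as `AttrStmt`, `DeclBuffer`, `LetStmt`, and `Allocate` all do):

    from tvm import tir

    def first_node_of_type(stmt, node_type):
        # Hypothetical helper: descend through wrapper nodes until a node
        # of `node_type` appears. Raises AttributeError if the chain ends
        # without a match.
        node = stmt
        while not isinstance(node, node_type):
            node = node.body
        return node

    # e.g. allocate = first_node_of_type(fdevice.body, tir.Allocate)
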
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 571927dffe6e..2cfc65aae069 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -57,7 +57,7 @@ def test_thread_storage_sync():
     func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     mod = run_passes(func)
     f = mod["test_kernel"]
-    body_list = tvm.tir.stmt_list(f.body.body.body)
+    body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
     assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
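
Unlike the warp-memory test, this one pins the exact nesting depth with six chained `.body` accesses, which will break again the next time a pass adds or removes a wrapper node above the statement sequence. If that brittleness becomes a problem, the same search-based pattern could be applied here; this is a suggested alternative, not part of the patch, and it assumes the statement fed to `tvm.tir.stmt_list` is a `tir.SeqStmt`:

    from tvm import tir

    def find_seq_stmt(func):
        # Sketch: locate the first SeqStmt in the kernel body instead of
        # hard-coding how many wrapper nodes (AttrStmt, DeclBuffer,
        # Allocate, ...) sit above it.
        node = func.body
        while not isinstance(node, tir.SeqStmt):
            node = node.body
        return node

    # body_list = tvm.tir.stmt_list(find_seq_stmt(f))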