From 9ff71f4a9fed3ec9f82b999fb8c25ef1bf6e243c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Aug 2023 16:27:24 -0500 Subject: [PATCH] [CodeGenC] Handle GlobalVar callee as internal function call (#15103) Analogous to #14901, treat GlobalVar callees as internal function calls in CodeGenC. This specific PR doesn't provide new end-to-end functionality, as the target="c" backend isn't compiled. It does lead into allowing subroutines in any target whose codegen derives from CodeGenC, which will depend on the single-module lowering flow in #14985. * [CodeGenC] Added unit tests for desired behavior * [CodeGenC] Handle GlobalVar callee as internal function call * Update CodeGenC subclasses for updated interface - Call `DeclareFunction` for each `PrimFunc`, prior to any `AddFunction` calls - Provide both `GlobalVar` and `PrimFunc` to `AddFunction` calls. * Updated CRT test to expect forward declaration * Provide forward declarations for call_extern in cmsis * Avoid duplicate forward declaration C's automatic pointer cast (e.g. `void*` to `int*`) means that use of the arguments to infer the function signature may be incorrect. If a `call_extern` refers to a function within the same module, only output a single forward declaration based on the PrimFunc's parameters, not based on the CallNode's arguments. * Updated expected ptx cuda * Cast the AOT pools to the arg type * Improved tvm::GetType for tvm_access_ptr and address_of These `Call` instances can return a `PointerType(PrimType(pointee_dtype))` rather than a `PrimType(DataType::Handle())`. * [ARM][Topi] Update micro kernels to use same argument type as caller Previously, the micro kernels for gemm, avg_pool, max_pool, and tensordot relied on C's implicit type conversions for the arguments, when the caller's argument types differ from the signature's parameter types. This works, except when the codegen has auto-generated a forward declaration based on the caller's argument types, such as during AOT, which then causes a conflicting definition. Since the codegen cannot determine the functions names from the `"pragma_import_c"` in order to suppress these forward declarations, this conflict can be more easily resolved by updating the micro kernel signatures. The three types of mismatches are below. - Use of `int` or `long` parameters, whose width may vary by compiler, instead of fixed-width types. - TIR expecting the data array's integer type to also be used as an error code's return type, rather than the micro kernels' `int32_t` error code. - Pointer conversion done during argument conversion. Type conversions are done at the start of each micro kernel, to avoid changing types that are used within the computational sections of each micro kernel. * Updated unit tests with private=True Required for internal functions after PR #15214 * Docstring updates from review --- .../mprofile/dsp/micro_kernel/avg_pool.py | 8 +- .../arm_cpu/mprofile/dsp/micro_kernel/gemm.py | 87 ++++++++-- .../mprofile/dsp/micro_kernel/max_pool.py | 13 +- .../mprofile/dsp/micro_kernel/tensordot.py | 7 +- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 28 ++-- .../example_target_hooks/tir_to_runtime.cc | 26 ++- .../backend/contrib/uma/tir_to_runtime.cc | 34 ++-- src/target/opt/build_cuda_on.cc | 18 ++- src/target/source/codegen_aocl.cc | 19 ++- src/target/source/codegen_c.cc | 153 ++++++++++++------ src/target/source/codegen_c.h | 59 ++++++- src/target/source/codegen_c_host.cc | 93 +++++------ src/target/source/codegen_c_host.h | 3 +- src/target/source/codegen_cuda.cc | 4 +- src/target/source/codegen_cuda.h | 2 +- src/target/source/codegen_metal.cc | 77 +++++---- src/target/source/codegen_metal.h | 3 +- src/target/source/codegen_opencl.cc | 24 ++- src/target/source/codegen_vhls.cc | 34 ++-- src/target/source/codegen_webgpu.cc | 79 ++++----- src/target/source/codegen_webgpu.h | 4 +- src/target/source/source_module.cc | 6 +- src/tir/op/op.cc | 26 +++ .../aot/test_crt_forward_declarations.py | 4 +- .../python/test_topi_conv2d_tensordot_opts.py | 28 +++- .../unittest/test_target_codegen_c_host.py | 48 +++++- ...est_tir_transform_inject_ptx_async_copy.py | 1 + 27 files changed, 591 insertions(+), 297 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py index e8e45152aae7..3eb32d8fdb16 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py @@ -55,7 +55,7 @@ def _body(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - cc.dtype, + "int32", f"{func_prefix}_{width}_{uniq_id}", aa.access_ptr("r"), cc.access_ptr("w"), @@ -68,7 +68,7 @@ def _body(): def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit( - tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w")) + tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w")) ) return ib.get() @@ -113,8 +113,8 @@ def sum_impl(N, uniq_id): __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}( int16_t *arr, int16_t *res16, - long arr_offset, - int reset) {{ + int32_t arr_offset, + int32_t reset) {{ int n; int32_t *p32; int32_t res = reset ? 0 : *res16; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py index 929dcc6557ff..e26e818fbd7e 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py @@ -156,9 +156,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}( - int K, + int32_t K_arg, int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 4) * 4; switch ( K % 4 ) {{ case 1: @@ -200,7 +205,12 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -221,7 +231,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int16_t bb_pad[{bb_pad_size}]; int32_t retcode = 0; @@ -265,9 +279,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}( - int K, + int32_t K_arg, int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 4) * 4; switch ( K % 4 ) {{ case 1: @@ -309,7 +328,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -327,7 +350,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int16_t bb_pad[{bb_pad_size}]; int32_t retcode = 0; @@ -368,9 +395,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}( - int K, + int32_t K_arg, int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 2) * 2; for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ @@ -387,7 +419,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -408,7 +444,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int32_t retcode = 0; if ( {M} < 2 && {N} < 2 ) {{ @@ -450,9 +490,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}( - int K, + int32_t K_arg, int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 2) * 2; for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ @@ -469,7 +514,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -487,7 +536,11 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int32_t retcode = 0; if ( {M} < 2 && {N} < 2 ) {{ @@ -520,7 +573,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id): #ifdef __cplusplus extern "C" #endif -__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{ +__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ cc[i*C_stride + j] = 0; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py index 66d712a4a0a2..cfed417c9fe7 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py @@ -46,7 +46,7 @@ def _body(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - cc.dtype, + "int32", f"{func_prefix}_{uniq_id}", aa.access_ptr("r"), cc.access_ptr("w"), @@ -59,7 +59,7 @@ def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0] + "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0] ) ) return ib.get() @@ -96,7 +96,7 @@ def max_impl(uniq_id): #endif __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}( int8_t *res, - int N) {{ + int32_t N) {{ memset(res, (int8_t)-128, N * sizeof(*res)); return 0; }} @@ -107,7 +107,9 @@ def max_impl(uniq_id): __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}( int8_t *arg, int8_t *res, - int N) {{ + int32_t N_arg) {{ + int N = N_arg; + for ( int i = 0; i < N; ++ i ) if ( arg[i] > res[i] ) res[i] = arg[i]; @@ -120,7 +122,8 @@ def max_impl(uniq_id): __attribute__((always_inline)) static inline int32_t max8_{uniq_id}( int8_t *arg, int8_t *res, - int N) {{ + int32_t N_arg) {{ + int N = N_arg; int32_t *parg32, *pres32; int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3; int32_t retcode = 0; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index d2a8f1ef6905..af3b23e01dcb 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -390,8 +390,13 @@ def insert_lines(lines): #define {function_name.upper()}_EXISTS #include __attribute__((always_inline)) static inline int32_t {function_name}( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) {{ + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + {_init_biased_accumulators(num_outputs)} {insert_lines(load_tensor_lines)} diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index ba2aea54bb91..ea2eabd76743 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices); } - /*! - * \brief Emit code that offloads a subgraph to the Cortex-M - * - * \return string of code that offloads a subgraph to the Cortex-M - */ - void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } - private: /*! * \brief Enable storing the last error */ bool debug_last_error; @@ -519,11 +512,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { bool emit_fwd_func_decl = false; bool debug_last_error = GetCompilerAttrs()->debug_last_error; CodeGenCMSISNN codegen; - Array function_names; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error); - std::vector> funcs; - for (auto kv : mod->functions) { - funcs.push_back(kv); + + std::vector> funcs; + for (auto [gvar, base_func] : mod->functions) { + funcs.push_back({gvar, Downcast(base_func)}); } std::sort(funcs.begin(), funcs.end(), @@ -538,13 +531,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { return name_hint_a < name_hint_b; }); - for (auto kv : funcs) { - auto prim_func = Downcast(kv.second); - auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); - function_names.push_back(global_symbol.value()); - codegen.AddFunction(prim_func); + for (auto [gvar, prim_func] : funcs) { + codegen.AddFunction(gvar, prim_func); } std::string code = codegen.Finish(); + + Array function_names; + for (auto [gvar, prim_func] : funcs) { + function_names.push_back(codegen.GetFunctionName(gvar)); + } + return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc index 0db8d06c3143..6f09e0a0c3f0 100644 --- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc +++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc @@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { bool emit_asserts = false; bool emit_fwd_func_decl = false; CodeGenExampleTargetHook codegen; - Array function_names; + std::unordered_set devices; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); - for (auto kv : mod->functions) { - auto prim_func = Downcast(kv.second); - auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); - function_names.push_back(global_symbol.value()); - codegen.AddFunction(prim_func); + + Map functions; + for (auto [gvar, base_func] : mod->functions) { + auto prim_func = Downcast(base_func); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + codegen.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl); } + std::string code = codegen.Finish(); + + Array function_names; + for (auto [gvar, prim_func] : functions) { + function_names.push_back(codegen.GetFunctionName(gvar)); + } + return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc index 3b58fda54b52..487e247f5d38 100644 --- a/src/relay/backend/contrib/uma/tir_to_runtime.cc +++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc @@ -49,13 +49,6 @@ class UMACodegen : public codegen::CodeGenCHost { CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices); } - /*! - * \brief Emit code that offloads a subgraph to the UMA target - * - * \return string of code that offloads a subgraph to the UMA target - */ - void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } - private: String target_str_; }; @@ -63,17 +56,30 @@ class UMACodegen : public codegen::CodeGenCHost { runtime::Module TIRToRuntime(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; - bool emit_fwd_func_decl = false; + bool emit_fwd_func_decl = true; UMACodegen codegen(target->kind->name); - Array function_names; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl); - for (auto kv : mod->functions) { - auto prim_func = Downcast(kv.second); - auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); - function_names.push_back(global_symbol.value()); - codegen.AddFunction(prim_func); + + Map functions; + for (auto [gvar, base_func] : mod->functions) { + auto prim_func = Downcast(base_func); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + codegen.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl); } + std::string code = codegen.Finish(); + + Array function_names; + for (auto [gvar, prim_func] : functions) { + function_names.push_back(codegen.GetFunctionName(gvar)); + } + return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 1c0b5094efab..e0f53e350992 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -131,13 +131,21 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { CodeGenCUDA cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); } std::string code = cg.Finish(); diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 700d85b4ccd4..dc3ba0875161 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -40,13 +40,22 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) { CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); } std::string code = cg.Finish(); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a7cc320562cb..187bdc74fe29 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -42,6 +42,7 @@ void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); handle_data_type_.clear(); CodeGenSourceBase::ClearFuncState(); + ReserveKeywordsAsUnique(); } void CodeGenC::ReserveKeywordsAsUnique() { @@ -75,51 +76,92 @@ void CodeGenC::ReserveKeywordsAsUnique() { name_supply_->ReserveName("return"); } -void CodeGenC::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - // reserve keywords - ReserveKeywordsAsUnique(); +void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) { + PrintFuncPrefix(os); + PrintType(func->ret_type, os); + PrintExtraAttrs(func, os); + os << " " << function_name << "("; + for (size_t i = 0; i < func->params.size(); ++i) { + tir::Var v = func->params[i]; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); - - this->PrintFuncPrefix(stream); - PrintType(f->ret_type, stream); - this->PrintExtraAttrs(f); - this->stream << " " << static_cast(global_symbol.value()) << "("; - - for (size_t i = 0; i < f->params.size(); ++i) { - tir::Var v = f->params[i]; - std::string vid = AllocVarID(v.get()); - if (i != 0) stream << ", "; - if (v.dtype().is_handle()) { - auto it = alloc_storage_scope_.find(v.get()); - if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); - } + if (i > 0) { + os << ", "; + } - PrintType(GetType(v), stream); - // Register handle data type - // TODO(tvm-team): consider simply keep type info in the - // type annotation(via a normalizing rewriting). - if (auto* ptr = v->type_annotation.as()) { - if (auto* prim = ptr->element_type.as()) { - RegisterHandleType(v.get(), prim->dtype); - } - } + if (auto it = alloc_storage_scope_.find(v.get()); it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, os); + } - if (no_alias) { - PrintRestrict(v, stream); + PrintType(GetType(v), os); + + bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias); + bool is_handle = v.dtype().is_handle(); + if (no_alias && is_handle) { + PrintRestrict(v, os); + } + + os << " " << AllocVarID(v.get()); + } + os << ")"; + + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). + for (const auto& param : func->params) { + if (auto* ptr = param->type_annotation.as()) { + if (auto* prim = ptr->element_type.as()) { + RegisterHandleType(param.get(), prim->dtype); } - } else { - PrintType(GetType(v), stream); } - stream << ' ' << vid; } - stream << ") {\n"; +} + +void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { + if (internal_functions_.count(gvar)) { + return; + } + + auto function_name = [&]() -> String { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = global_symbol.value(); + ICHECK(!func_name_supply_->ContainsName(name)) + << "Function " << gvar << " must use global symbol " << name + << ", but this name has already been used."; + func_name_supply_->ReserveName(name); + return name; + } else { + func_name_supply_->ReserveName(gvar->name_hint); + return gvar->name_hint; + } + }(); + + internal_functions_.insert({gvar, function_name}); + + InitFuncState(func); + PrintFunctionSignature(function_name, func, fwd_decl_stream); + fwd_decl_stream << ";\n"; +} + +String CodeGenC::GetFunctionName(const GlobalVar& gvar) { + auto it = internal_functions_.find(gvar); + ICHECK(it != internal_functions_.end()) + << "Attempted to find name of " << gvar + << ", but no function with this GlobalVar has been declared"; + return it->second; +} + +void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { + // If the function has already been forward-declared, this is a + // no-op. + DeclareFunction(gvar, f); + auto function_name = GetFunctionName(gvar); + + // clear previous generated state. + InitFuncState(f); + + PrintFunctionSignature(function_name, f, stream); + stream << " {\n"; this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); @@ -130,9 +172,15 @@ void CodeGenC::AddFunction(const PrimFunc& f) { void CodeGenC::PrintFuncPrefix(std::ostream& os) {} -void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} +void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {} -std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } +std::string CodeGenC::Finish() { + std::ostringstream code; + code << decl_stream.str(); + code << fwd_decl_stream.str(); + code << stream.str(); + return code.str(); +} void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { @@ -542,12 +590,17 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); - Array arg_types; - for (size_t i = 1; i < op->args.size(); i++) { - arg_types.push_back(GetType(op->args[i])); + + // If the call_extern refers to an function within the IRModule, then + // the forward declaration is already provided from DeclareFunction. + if (!func_name_supply_->ContainsName(func->value)) { + Array arg_types; + for (size_t i = 1; i < op->args.size(); i++) { + arg_types.push_back(GetType(op->args[i])); + } + Type ret_type = GetTypeFromRuntimeDataType(op->dtype); + this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); } - Type ret_type = GetTypeFromRuntimeDataType(op->dtype); - this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], @@ -615,9 +668,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else { LOG(FATAL) << "Unresolved call " << op->op; } + } else if (auto opt = op->op.as()) { + auto gvar = opt.value(); + auto callee_name = GetFunctionName(gvar); + PrintCallExtern(GetType(GetRef(op)), callee_name, op->args, false, os); } else { - ICHECK(op->op.as()); - LOG(FATAL) << "Do not yet support cross function call"; + LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, " + << "nor a GlobalVar reference to another function in the IRModule"; } } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 93f9ea519c23..2921a56ef3a1 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -65,12 +65,33 @@ class CodeGenC : public ExprFunctor, * \param output_ssa Whether output SSA. */ void Init(bool output_ssa); + /*! - * \brief Add the function to the generated module. - * \param f The function to be compiled. + * \brief Add the function declaration to the generated module, + * without defining it. + * + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. * \param whether to append return 0 in the end. */ - void AddFunction(const PrimFunc& f); + virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func); + + /*! + * \brief Add the function to the generated module, including its + * declaration and definition. + * + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. + */ + virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func); + + /*! + * \brief Get the name of a declared function + * \param gvar The GlobalVar of the function + * \returns The string name of the function + */ + String GetFunctionName(const GlobalVar& gvar); + /*! * \brief Finalize the compilation and return the code. * \return The code. @@ -96,7 +117,23 @@ class CodeGenC : public ExprFunctor, PrintExpr(n, os); return os.str(); } + // The following parts are overloadable print operations. + + /*! \brief Print the function signature before the argument list + * + * The default implementation delegates out to PrintFuncPrefix and + * PrintExtraAttrs. + * + * \param function_name The name of the function + * + * \param func The function whose signature should be printed + * + * \param os The output stream + */ + virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os); + /*! * \brief Print the function header before the argument list * \param os The output stream @@ -109,7 +146,7 @@ class CodeGenC : public ExprFunctor, * * Example: __launch_bounds__(256) for CUDA functions */ - virtual void PrintExtraAttrs(const PrimFunc& f); + virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. @@ -284,10 +321,24 @@ class CodeGenC : public ExprFunctor, private: /*! \brief set of volatile buf access */ std::unordered_set volatile_buf_; + // deep comparison of PrimExpr ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map let_binding_; + + /* \brief Map of GlobalVar to their symbol. + * + * For externally-exposed functions, this is given by the + * tvm::attr::kTarget attribute of the PrimFunc. For internal + * functions, this is the name of the function's GlobalVar, possibly + * altered to prevent duplicate names. + */ + std::unordered_map internal_functions_; + + /* \brief Name supply to generate unique function names */ + NameSupply func_name_supply_{""}; }; } // namespace codegen diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 3255e11c5d36..caef43e8af28 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -75,19 +75,24 @@ void CodeGenCHost::InitGlobalContext() { void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } -void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; - function_names_.push_back(global_symbol.value()); +void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, + bool emit_fwd_func_decl) { + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (global_symbol) { + function_names_.push_back(global_symbol.value()); + } emit_fwd_func_decl_ = emit_fwd_func_decl; - CodeGenC::AddFunction(f); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + CodeGenC::AddFunction(gvar, func); + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + ICHECK(global_symbol.defined()) + << "CodeGenCHost: The entry func must have the global_symbol attribute, " + << "but function " << gvar << " only has attributes " << func->attrs; + function_names_.push_back(runtime::symbol::tvm_module_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); - PrintType(f->ret_type, stream); + PrintType(func->ret_type, stream); stream << " " << tvm::runtime::symbol::tvm_module_main << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " << "int* out_ret_tcode, void* resource_handle) {\n"; @@ -128,15 +133,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) << "TVM_DLL "; } -std::string CodeGenCHost::Finish() { // NOLINT(*) - std::string ret = decl_stream.str(); - if (emit_fwd_func_decl_) { - ret += fwd_decl_stream.str(); - } - ret += stream.str(); - return ret; -} - void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { @@ -437,42 +433,38 @@ runtime::Module BuildCHost(IRModule mod, Target target) { CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); cg.SetConstantsByteAlignment(target->GetAttr("constants-byte-alignment").value_or(16)); - PrimFunc aot_executor_fn; - - std::vector> funcs; - for (auto kv : mod->functions) { - // Make sure that the executor function is the last one to be code generated so that all the - // symbols are available to __tvm_main__ - auto fun_name = std::string(kv.first->name_hint); - bool is_aot_executor_fn = kv.second->GetAttr("runner_function", Bool(false)).value(); - - if (is_aot_executor_fn) { - aot_executor_fn = Downcast(kv.second); - continue; - } - funcs.push_back(kv); + + auto is_aot_executor_fn = [](const PrimFunc& func) -> bool { + return func->GetAttr("runner_function", Bool(false)).value(); + }; + + std::vector> funcs; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + funcs.push_back({gvar, prim_func}); } // Sort functions - std::sort(funcs.begin(), funcs.end(), - [](std::pair kv_a, - std::pair kv_b) { - std::string name_hint_a = kv_a.first->name_hint; - std::string name_hint_b = kv_b.first->name_hint; - return name_hint_a < name_hint_b; - }); - - // Add all functions except __tvm_main__ - for (auto& kv : funcs) { - ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; - auto f = Downcast(kv.second); - cg.AddFunction(f); + auto sort_key = [&is_aot_executor_fn](const auto& kv) { + return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint}; + }; + std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) { + return sort_key(kv_a) < sort_key(kv_b); + }); + + // Declare all functions first. This ensures that all functions, + // including the __tvm_main__ used in AOT, have access to forward + // declarations of other functions in the IRModule. + for (const auto& [gvar, prim_func] : funcs) { + cg.DeclareFunction(gvar, prim_func); } - // Add __tvm_main__ - if (aot_executor_fn.defined()) { - emit_fwd_func_decl = true; - cg.AddFunction(aot_executor_fn, emit_fwd_func_decl); + // Codegen all functions. Passing emit_fwd_func_decl=true adds a + // forward declaration for any `builtin::call_extern`, based on the + // arguments provided to it. + for (const auto& [gvar, prim_func] : funcs) { + cg.AddFunction(gvar, prim_func, emit_fwd_func_decl); } // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build(). @@ -484,7 +476,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) { } else { runtime = relay::Runtime::Create("cpp", {}); } - if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) { + + bool has_aot_executor_fn = std::any_of( + funcs.begin(), funcs.end(), [&](const auto& kv) { return is_aot_executor_fn(kv.second); }); + if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) { cg.InitGlobalContext(); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 694104afc0af..aeba685f7422 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,8 +44,7 @@ class CodeGenCHost : public CodeGenC { const std::unordered_set& devices); void InitGlobalContext(); - void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false); - std::string Finish() final; + void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl = false); /*! * \brief Add functions from the (unordered) range to the current module in a deterministic * order. This helps with debugging. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 22103f7b0ff2..6c0234819199 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor { PrimExpr threadIdx_z_ext = Integer(1); }; -void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) { +void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { ThreadIdxExtractor extractor; extractor(f->body); arith::Analyzer analyzer; @@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) { // unable to extract the number of threads per block, hence directly return return; } - stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; + os << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; } } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index c6cf96d460d4..7de6ae05e87d 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC { } // override behavior void PrintFuncPrefix(std::ostream& os) final; - void PrintExtraAttrs(const PrimFunc& f) final; + void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b8c30691e21f..3db8d216b3b1 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -36,6 +36,8 @@ namespace codegen { void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); + // skip the first underscore, so SSA variable starts from _1 + name_supply_->FreshName("v_"); // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -52,37 +54,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) { << "};\n\n"; } -void CodeGenMetal::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("v_"); - +void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) { // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; + os << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; size_t limit = target_->GetAttr("max_function_args").value().IntValue(); - if (f->params.size() > limit) { + if (func->params.size() > limit) { LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " "buffers in the kernel"; } - for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { - Var v = f->params[i]; + for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { + Var v = func->params[i]; if (!v.dtype().is_handle()) break; - stream << " "; + os << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); + PrintStorageScope(it->second, os); } - PrintType(GetType(v), stream); + PrintType(GetType(v), os); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -91,19 +89,18 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { RegisterHandleType(v.get(), prim->dtype); } } - stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; + os << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. - size_t nargs = f->params.size() - num_buffer; + size_t nargs = func->params.size() - num_buffer; std::string varg = name_supply_->FreshName("arg"); if (nargs != 0) { std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; - stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer - << ") ]],\n"; + os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; - for (size_t i = num_buffer; i < f->params.size(); ++i) { - Var v = f->params[i]; + for (size_t i = num_buffer; i < func->params.size(); ++i) { + Var v = func->params[i]; ICHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; @@ -131,7 +128,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; - auto launch_params = f->GetAttr>(tir::attr::kKernelLaunchParams).value(); + auto launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams).value(); for (const auto& tag : launch_params) { if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); @@ -150,13 +147,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { } thread_work_dim_ = work_dim; - // the function scope. - stream << ") {\n"; - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->EndScope(func_scope); - this->PrintIndent(); - this->stream << "}\n\n"; + stream << ")"; } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { @@ -342,27 +333,33 @@ runtime::Module BuildMetal(IRModule mod, Target target) { const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile"); std::string fmt = fmetal_compile ? "metallib" : "metal"; - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); - std::string func_name = global_symbol.value(); + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + auto calling_conv = base_func->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + + auto prim_func = Downcast(base_func); + functions.Set(gvar, prim_func); + } - source_maker << "// Function: " << func_name << "\n"; + for (auto [gvar, prim_func] : functions) { + source_maker << "// Function: " << gvar->name_hint << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); + std::string fsource = cg.Finish(); source_maker << fsource << "\n"; if (fmetal_compile) { fsource = (*fmetal_compile)(fsource, target).operator std::string(); } - smap[func_name] = fsource; + smap[cg.GetFunctionName(gvar)] = fsource; } return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 36be10d16363..26c991e60df9 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -38,7 +38,8 @@ class CodeGenMetal final : public CodeGenC { explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); - void AddFunction(const PrimFunc& f); // NOLINT(*) + void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) override; void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index c15d2253d716..da6a4de6196a 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -595,18 +595,26 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + functions.Set(gvar, prim_func); + } + std::stringstream code; const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc"); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; - code << "// Function: " << kv.first->name_hint << std::endl; + for (auto [gvar, prim_func] : functions) { + code << "// Function: " << gvar->name_hint << std::endl; CodeGenOpenCL cg; cg.Init(output_ssa); - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); std::string fsource = cg.Finish(); if (fpostproc) { fsource = (*fpostproc)(fsource, target).operator std::string(); diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 83046de10701..aa7a32320c5e 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -145,13 +145,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { // Generate source code for get_source(). cg.Init(output_ssa); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); } std::string whole_code = cg.Finish(); @@ -159,21 +167,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { // Generate source code for compilation. Array> kernel_info; - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; - auto f = Downcast(kv.second); + for (auto [gvar, prim_func] : functions) { CodeGenVivadoHLS cg; cg.Init(output_ssa); - cg.AddFunction(f); + + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); std::string code = cg.Finish(); if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { code = (*f)(code, target).operator std::string(); } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - kernel_info.push_back({global_symbol.value(), code}); + auto function_name = cg.GetFunctionName(gvar); + kernel_info.push_back({function_name, code}); } std::string xclbin; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 4d1d834c7fac..6a6712a4ce26 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -45,6 +45,12 @@ std::string CodeGenWebGPU::Finish() { void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); + // skip the first underscore, so SSA variable starts from + name_supply_->FreshName("v_"); + // Setup the thread group info. + ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -56,28 +62,12 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} -void CodeGenWebGPU::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - // skip the first underscore, so SSA variable starts from - name_supply_->FreshName("v_"); - // Setup the thread group info. - ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); - ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); - - // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; - - decl_stream << "//----------------------------------------\n" - << "// function: " << global_symbol.value() << "\n" - << "//----------------------------------------\n"; - +void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) { std::vector pod_args; int num_buffer = 0; // setup buffer argumemts - for (Var arg : f->params) { + for (Var arg : func->params) { DataType t = arg.dtype(); if (t.is_handle()) { auto* ptr = arg->type_annotation.as(); @@ -111,16 +101,18 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { } // add to alloc buffer type. // Function header. - this->stream << "fn main(\n" - << " @builtin(workgroup_id) blockIdx : vec3,\n" - << " @builtin(local_invocation_id) threadIdx : vec3\n" - << ") {\n"; - // the function scope. - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->EndScope(func_scope); - this->PrintIndent(); - this->stream << "}\n\n"; + os << "fn main(\n" + << " @builtin(workgroup_id) blockIdx : vec3,\n" + << " @builtin(local_invocation_id) threadIdx : vec3\n" + << ")"; +} + +void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { + CodeGenC::AddFunction(gvar, func); + decl_stream << "//----------------------------------------\n" + << "// function: " << GetFunctionName(gvar) << "\n" + << "//----------------------------------------\n"; + // anotate workgroup this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", " << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n"; @@ -524,22 +516,31 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; - std::unordered_map smap; - for (auto kv : mod->functions) { - CodeGenWebGPU cg(target); - ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol.value(); + functions.Set(gvar, prim_func); + } + + std::unordered_map smap; + for (auto [gvar, prim_func] : functions) { + CodeGenWebGPU cg(target); cg.Init(output_ssa); - cg.AddFunction(f); + + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); + std::string code = cg.Finish(); - smap[f_name] = code; + smap[cg.GetFunctionName(gvar)] = code; } auto n = make_object(smap, ExtractFuncInfo(mod)); return runtime::Module(n); diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index 57f226ba8ad6..6ae942a3ad49 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -48,7 +48,9 @@ class CodeGenWebGPU final : public CodeGenC { explicit CodeGenWebGPU(Target target); // overrides std::string Finish() final; - void AddFunction(const PrimFunc& f); // NOLINT(*) + void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) final; + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final; void InitFuncState(const PrimFunc& f) final; void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index be5179e081a1..c75f3008ef6b 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -574,12 +574,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } for (const tir::Var& pool_var : metadata_->pools) { + call_args_ss << "((uint8_t*)"; String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name; if (IsInternalWorkspaceBuffer(pool_var)) { - call_args_ss << "&" << pool_name << ","; + call_args_ss << "&" << pool_name; } else { - call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ","; + call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name); } + call_args_ss << "),"; } for (const String& device : metadata_->devices) { call_args_ss << "devices->" << device << ","; diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 39214c4546dc..fd14f4892154 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) { return ptr->type_annotation; } } + + if (auto* access = expr.as()) { + if (access->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments"; + auto type_annotation = Downcast(access->args[0]); + static auto builtin_op = Op::Get("tir.type_annotation"); + ICHECK(type_annotation->op.same_as(builtin_op)) + << "Expected the first argument of builtin tvm_access_ptr() " + << "to be a type annotation, but found " << type_annotation->op; + return PointerType(PrimType(type_annotation->dtype)); + } + } + + if (auto* address_of = expr.as()) { + if (address_of->op.same_as(builtin::address_of())) { + ICHECK_EQ(address_of->args.size(), 1) + << "Builtin address_of() expects a single argument, but received arguments " + << address_of->args; + auto* address = address_of->args[0].as(); + ICHECK(address) + << "Builtin address_of() expects the argument to be a BufferLoad, but received argument " + << address_of->args[0]; + + return PointerType(PrimType(address->dtype)); + } + } // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); return GetTypeFromRuntimeDataType(dtype); diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py index 17af7a5d682d..0c73f18e8a4c 100644 --- a/tests/python/relay/aot/test_crt_forward_declarations.py +++ b/tests/python/relay/aot/test_crt_forward_declarations.py @@ -160,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner): lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0] main_source = lib_mod.get_source() - assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1 - assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3 + assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2 + assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 6 @tvm.testing.requires_corstone300 diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py index 7bea7577b6bf..f6145cd1c51a 100644 --- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py +++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py @@ -135,8 +135,13 @@ def test_write_3x3_depthwise_code(): #define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -188,8 +193,13 @@ def test_odd_width_3x3_depthwise_strides_code(): #define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias, sum_1 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -251,8 +261,13 @@ def test_1x1x8_convolution_code(): #define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -349,8 +364,13 @@ def test_3x3x3_offset_convolution_code(): #define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS #include __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias; int32_t tensor__unknown__y00_x00 = tensor[0]; diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index d02f8744f129..3aca0fc8c77e 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -14,11 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import tvm import tvm.testing + from tvm import te -import numpy as np from tvm.contrib import utils +from tvm.script import tir as T, ir as I + +import numpy as np def test_add(): @@ -228,11 +232,39 @@ def check_global_packed_func(): check_global_packed_func() +def test_subroutine_call(): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, dtype="float32")): + mod.subroutine(A.data) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32")): + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 42.0 + + built = tvm.build(mod, target="c") + + func_names = list(built["get_func_names"]()) + assert ( + "main" in func_names + ), "Externally exposed functions should be listed in available functions." + assert ( + "subroutine" not in func_names + ), "Internal function should not be listed in available functions." + + source = built.get_source() + assert ( + source.count("main(void*") == 2 + ), "Expected two occurrences, for forward-declaration and definition" + assert ( + source.count("subroutine(float*") == 2 + ), "Expected two occurrences, for forward-declaration and definition" + assert ( + source.count("subroutine(") == 3 + ), "Expected three occurrences, for forward-declaration, definition, and call from main." + + if __name__ == "__main__": - test_add() - test_add_pipeline() - test_reinterpret() - test_ceil() - test_floor() - test_round() - test_call_packed() + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 3543f798c36e..b39fca72c871 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -204,6 +204,7 @@ def test_inject_async_copy_shared_dyn(): #define int64_t long long #define uint64_t unsigned long long #endif +extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C); extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64];