From 521c16d3eaf1309e0a421ebef1ebd40675f5d3c9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 11:03:59 -0500 Subject: [PATCH 01/16] PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript Prior to this commit, the `TargetNode::host` could be specified in TVMScript as part of the config dictionary, under the key `"host"`. However, this required all other device parameters to be explicitly specified, rather than using any of the short-hand string representations. This commit forwards the `host` argument from TVMScript's `T.target` method to `tvm.target.Target`, allowing both the device and host to be specified using the shorthand string representation. ```python @T.prim_func def before_this_commit(): T.func_attr( { "target": T.target( { "arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32, } ) } ) T.evaluate(0) @T.prim_func def after_this_commit(): T.func_attr({"target": T.target("cuda", host="llvm")}) T.evaluate(0) ``` --- python/tvm/script/ir_builder/tir/ir.py | 22 +++++++++++++++++-- .../unittest/test_tvmscript_roundtrip.py | 10 +++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c8285ccc52ce..a6a21ea9402a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1655,7 +1655,10 @@ def index_map( return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) -def target(target_config: Union[Dict, str]) -> Target: +def target( + target_config: Union[Dict, str], + host: Optional[Union[Dict, str, Target]] = None, +) -> Target: """ Create a target @@ -1664,6 +1667,9 @@ def target(target_config: Union[Dict, str]) -> Target: target_config : Union[Dict, str] The target configuration. + host : Optional[Union[Dict, str, Target]] + The host target configuration. 
+ Returns ------- res : Target @@ -1673,7 +1679,19 @@ def target(target_config: Union[Dict, str]) -> Target: raise ValueError( f"T.target expected a config dict or string, but got {type(target_config)}" ) - return Target(target_config) + if host is not None and not isinstance(host, (str, dict, Target)): + raise ValueError( + "T.target expected the host to be " + "a config dict, string, or T.target, " + f"but got {type(host)}" + ) + if isinstance(target_config, dict) and "host" in target_config and host is not None: + raise ValueError( + "T.target expects to either receive the host " + "as part of the target's config dictionary, " + "or as a separate argument, but not both." + ) + return Target(target_config, host) def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 2ea7d3ec6579..e3ec311cc0c5 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3123,6 +3123,15 @@ def func_with_target_spec_by_str() -> None: return func_with_target_spec_by_str +def func_with_target_and_host_spec_by_str(): + @T.prim_func + def func(): + T.func_attr({"target": T.target("nvidia/nvidia-a100", host="llvm")}) + T.evaluate(0) + + return func + + def func_root_attr(): @T.prim_func def func_root_attr(): @@ -3883,6 +3892,7 @@ def func(): nontrivial_range_axis, func_with_target_spec_by_config, func_with_target_spec_by_str, + func_with_target_and_host_spec_by_str, func_root_attr, func_trivial_root_block, func_nested_root_block, From dfda99a67d8f74c2bb36f87f2859276db33ab6de Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Mar 2023 13:48:31 -0500 Subject: [PATCH 02/16] [Target] Added WithoutHost method --- include/tvm/target/target.h | 3 +++ src/target/target.cc | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/include/tvm/target/target.h 
b/include/tvm/target/target.h index 891700b86a4c..5c88807682d7 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -218,6 +218,9 @@ class Target : public ObjectRef { */ static Target WithHost(const Target& target, const Target& host); + /*! \return The target with the host stripped out */ + Target WithoutHost() const; + /*! * \brief Returns true if \p this target represents an external codegen. If so, * \p this->kind->name can be used as the "Compiler" attribute on partitioned functions, diff --git a/src/target/target.cc b/src/target/target.cc index f05d4db2b888..e479f592c640 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -662,6 +662,16 @@ Map TargetNode::Export() const { Optional TargetNode::GetHost() const { return this->host.as(); } +Target Target::WithoutHost() const { + if ((*this)->GetHost()) { + auto output = make_object(*get()); + output->host = NullOpt; + return Target(output); + } else { + return *this; + } +} + int TargetNode::GetTargetDeviceType() const { if (Optional device_type = GetAttr("target_device_type")) { return Downcast(device_type)->value; From b5792cefc43d9c1f71e4fdc9a7f1e65e6dabbfbc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Mar 2023 08:55:51 -0500 Subject: [PATCH 03/16] [TIR] SplitHostDevice, handle missing kGlobalSymbol Previously, the symbol name of the extracted compute kernel was defined based on the `kGlobalSymbol` attribute, which was required to be present. This commit updates `SplitHostDevice` to generate the symbol name using `kGlobalSymbol` if present, and to fall back to the name of the `tvm::GlobalVar` for internal functions. 
--- src/tir/transforms/split_host_device.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 4f47b8ce2bf9..3ae114f31877 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -249,15 +249,13 @@ class HostDeviceSplitter : public StmtMutator { std::unordered_map handle_data_type_; }; -PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { +PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod, const GlobalVar& gvar) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; + auto name_prefix = global_symbol.value_or(gvar->name_hint); - HostDeviceSplitter splitter(device_mod, target.value(), - static_cast(global_symbol.value())); + HostDeviceSplitter splitter(device_mod, target.value(), name_prefix); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); @@ -275,10 +273,12 @@ Pass SplitHostDevice() { IRModule device_mod = IRModule(Map({})); for (auto& kv : *func_dict) { - if (kv.second->IsInstance()) { - PrimFunc func = Downcast(std::move(kv.second)); + auto gvar = Downcast(kv.first); + auto& base_func = kv.second; + if (base_func->IsInstance()) { + PrimFunc func = Downcast(std::move(base_func)); ICHECK(device_mod.defined()) << "The device module must be defined."; - kv.second = SplitHostDevice(std::move(func), &device_mod); + base_func = SplitHostDevice(std::move(func), &device_mod, gvar); } } mod->Update(device_mod); From ec2f8be9bd1d0e096f77fa241a21d32dffe239ac Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 14:52:42 -0500 Subject: [PATCH 04/16] [TIR] Refactor SplitHostDevice into three separate passes 
First pass, `AnnotateDeviceRegions`. This pass decides which portions of a PrimFunc should be run on the device, and annotates them with `kTarget` attribute, indicating which target should be used for later lowering steps. Second pass, `SplitHostDevice`. This pass extracts the annotated region into an independent PrimFunc. The `kTarget` attribute of the extracted kernel is defined by the `kTarget` annotation inserted by `AnnotateDeviceRegions`. The host function is marked by the `tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized by later host-only lowering passes. Third pass, `LowerDeviceKernelLaunch`. This pass identifies subroutine calls that call into device kernels, and rewrites them into `T.tvm_call_packed`. --- include/tvm/tir/transform.h | 38 +++ python/tvm/tir/op.py | 2 +- python/tvm/tir/transform/transform.py | 38 +++ src/driver/driver_api.cc | 3 + src/tir/transforms/annotate_device_regions.cc | 81 +++++ .../transforms/lower_device_kernel_launch.cc | 282 ++++++++++++++++++ src/tir/transforms/split_host_device.cc | 254 +++++----------- .../test_tir_transform_split_host_device.py | 31 +- 8 files changed, 547 insertions(+), 182 deletions(-) create mode 100644 src/tir/transforms/annotate_device_regions.cc create mode 100644 src/tir/transforms/lower_device_kernel_launch.cc diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 8dee176277d7..d9d68e0a8b6a 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes(); */ TVM_DLL Pass DecorateDeviceScope(); +/*! + * \brief Annotate locations that should be run on the device + * + * Insert `AttrStmt` nodes specifying a target on which regions within + * the PrimFunc should be executed. Only modifies functions that have + * a `tvm::attr::kTarget` attribute, and where that target defines a + * host. + * + * \return The pass. + */ +TVM_DLL Pass AnnotateDeviceRegions(); + /*! 
* \brief Split the function into a host function and device functions. * + * The resulting host-side function will keep the same + * `tvm::attr::kTarget` attribute (e.g. `T.target("cuda", + * host=T.target("llvm"))`). This ensures that `MakePackedAPI` knows + * which device type should be used for the input buffers. + * + * The resulting device-side function will + * have the host stripped from its target attribute + * (e.g. `T.target("cuda")`). + * * \return The pass. */ TVM_DLL Pass SplitHostDevice(); +/*! + * \brief Lower cross-device function calls. + * + * Prior to this pass, host to device calls are represented as + * subroutine calls, with environment parameters (e.g. env_thread) + * specified internally. The device function is an internal function, + * without a `tvm::attr::kGlobalSymbol` attribute. + * + * After this pass, host to device calls are represented as + * tvm_call_packed built-in. The device function is an + * externally-exposed function, with a non-empty + * `tvm::attr::kGlobalSymbol` attribute. + * + * \return The pass. + */ +TVM_DLL Pass LowerDeviceKernelLaunch(); + /*! * \brief skip assert stmt. * diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 90e3db4cb96b..098c13f04e9d 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): The call expression. 
""" assert isinstance(global_var, tvm.ir.GlobalVar) - return Call(dtype="handle", op=global_var, args=args) + return Call(dtype="void", op=global_var, args=args) def start_profile_intrinsic(id): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f2ce4378141e..9e038f618bc3 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -435,6 +435,22 @@ def MakeUnpackedAPI(): return _ffi_api.MakeUnpackedAPI() # type: ignore +def AnnotateDeviceRegions(): + """Annotate locations that should be run on the device + + Insert `AttrStmt` nodes specifying a target on which regions + within the PrimFunc should be executed. Only modifies functions + that have a `tvm::attr::kTarget` attribute, and where that target + defines a host. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateDeviceRegions() # type: ignore + + def SplitHostDevice(): """Split the function into a host function and device functions. @@ -446,6 +462,28 @@ def SplitHostDevice(): return _ffi_api.SplitHostDevice() # type: ignore +def LowerDeviceKernelLaunch(): + """Lower cross-device function calls. + + Prior to this pass, host to device calls are represented as + subroutine calls, with environment parameters (e.g. env_thread) + specified internally. The device function is an internal + function, without a `tvm::attr::kGlobalSymbol` attribute. + + After this pass, host to device calls are represented as + tvm_call_packed built-in. The device function is an + externally-exposed function, with a non-empty + `tvm::attr::kGlobalSymbol` attribute. + + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerDeviceKernelLaunch() # type: ignore + + def DecorateDeviceScope(): """Decorate all the function's body as device function. 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 91bc57ccbeb2..e5f71c38320d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); + + mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); } diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc new file mode 100644 index 000000000000..a81af7d7805b --- /dev/null +++ b/src/tir/transforms/annotate_device_regions.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file annotate_device_regions.cc + * \brief Annotate locations that should be run on the device. 
+ */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class DeviceRegionAnnotater : public StmtMutator { + public: + explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tvm::attr::kTarget) { + // If a target attribute already exists, use it as-is. + return GetRef(op); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || + op->attr_key == attr::device_scope) { + // These attributes are only allowed in device-side code, so + // they should be annotated with the function's default target. + Stmt body = GetRef(op); + return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); + } else { + // All other annotations are ignored + return StmtMutator::VisitStmt_(op); + } + } + + private: + Target device_target_; +}; + +namespace transform { + +Pass AnnotateDeviceRegions() { + auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { + auto opt_target = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; + Target target = opt_target.value(); + + if (target->GetHost()) { + DeviceRegionAnnotater mutator(target.WithoutHost()); + func.CopyOnWrite()->body = mutator(func->body); + } + return func; + }; + + return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc new file mode 100644 index 000000000000..3a089ee57e94 --- /dev/null +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_device_kernel_launch.cc + * \brief Lower cross-device function calls. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +namespace { +struct DeviceInfo { + Target target; + Array params; + Optional> launch_params; + Map thread_extent; + Optional dyn_shmem_size{NullOpt}; + + PrimExpr GetArgument(const String& launch_param) const { + if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { + CHECK(dyn_shmem_size.defined()) + << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc did not contain Allocate node with shared dynamic scope."; + return dyn_shmem_size.value(); + } + + auto extent = thread_extent.Get(launch_param); + CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent + << "\" defining this thread extent"; + return extent.value(); + } +}; + +/*! + * \brief Visitor class to collect device-side program information. 
+ */ +class DeviceInfoCollector : public StmtVisitor { + public: + static DeviceInfo Collect(const PrimFuncNode* func) { + DeviceInfoCollector collector; + collector.info_.target = [&]() -> Target { + auto target_attr = func->GetAttr(tvm::attr::kTarget).value(); + bool is_host_func = + func->GetAttr(tvm::tir::attr::kIsHostFunc).value_or(Bool(false))->value; + if (is_host_func) { + return target_attr->GetHost().value(); + } else { + return target_attr.WithoutHost(); + } + }(); + collector.info_.params = func->params; + collector.info_.launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams); + collector(func->body); + return collector.info_; + } + + private: + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. + if (!defined_thread.count(iv.get())) { + defined_thread.insert(iv.get()); + info_.thread_extent.Set(iv->thread_tag, op->value); + } + } + + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocateNode* op) final { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK(!info_.dyn_shmem_size.defined()) + << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + + PrimExpr dyn_size = Integer(1); + for (const auto& extent : op->extents) { + dyn_size *= extent; + } + dyn_size *= op->dtype.bytes(); + + info_.dyn_shmem_size = dyn_size; + } + StmtVisitor::VisitStmt_(op); + } + + DeviceInfo info_; + // recording what thread axis have been visited. 
+ std::unordered_set defined_thread; +}; +} // namespace + +class DeviceKernelMutator : public StmtExprMutator { + public: + using Parent = StmtExprMutator; + + explicit DeviceKernelMutator(std::unordered_map device_info_map) + : device_info_map_(std::move(device_info_map)) {} + + PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) { + ICHECK(!current_target_.defined()); + auto it = device_info_map_.find(gvar.get()); + ICHECK(it != device_info_map_.end()); + current_target_ = it->second.target; + + auto body = VisitStmt(func->body); + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + + current_target_ = NullOpt; + return func; + } + + PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { + if (device_kernel_launch_.count(gvar.get())) { + Map new_attrs; + new_attrs.Set(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)); + new_attrs.Set(tvm::tir::attr::kIsGlobalFunc, Bool(true)); + + if (!func->GetAttr(tvm::attr::kGlobalSymbol)) { + new_attrs.Set(tvm::attr::kGlobalSymbol, gvar->name_hint); + } + + func = WithAttrs(std::move(func), new_attrs); + } + + return func; + } + + private: + PrimExpr VisitExpr_(const CallNode* op) { + auto node = Downcast(Parent::VisitExpr_(op)); + + auto* gvar = op->op.as(); + if (!gvar) return std::move(node); + + auto it = device_info_map_.find(gvar); + ICHECK(it != device_info_map_.end()) + << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " + << gvar->name_hint << " did not appear within the IRModule"; + const DeviceInfo& dev_info = it->second; + + auto caller_device_type = current_target_.value()->GetTargetDeviceType(); + auto callee_device_type = dev_info.target->GetTargetDeviceType(); + if (caller_device_type == callee_device_type) { + return std::move(node); + } + + ICHECK(dev_info.launch_params.defined()) + << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " + << dev_info.target << ", but 
subroutine " << gvar->name_hint + << " did not have the tir::attr::kKernelLaunchParams attribute " + << "required for cross-target kernel launch"; + + // Collected kernel information may be in terms of the callee's + // arguments, but we need expressions for them in terms of the + // caller's parameters. The param_map allows substitution of + // parameter values into the thread extents, to generate + // expressions that are valid within the caller. + Map param_map = [&]() { + Map param_map; + CHECK_EQ(node->args.size(), dev_info.params.size()) + << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() + << " arguments as input, but is called using " << node->args.size() << " arguments"; + for (size_t i = 0; i < node->args.size(); i++) { + param_map.Set(dev_info.params[i], node->args[i]); + } + return param_map; + }(); + + device_kernel_launch_.insert(gvar); + + Array call_args; + call_args.push_back(StringImm(gvar->name_hint)); + for (PrimExpr arg : node->args) { + call_args.push_back(arg); + } + for (const auto& launch_param : dev_info.launch_params.value()) { + call_args.push_back(Substitute(dev_info.GetArgument(launch_param), param_map)); + } + + auto dtype = node->dtype.is_void() ? 
DataType::Int(32) : node->dtype; + + return Call(dtype, builtin::tvm_call_packed(), call_args); + } + + Optional current_target_; + std::unordered_map device_info_map_; + std::unordered_set device_kernel_launch_; +}; + +namespace transform { + +Pass LowerDeviceKernelLaunch() { + auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { + auto mutator = [&mod]() { + std::unordered_map device_info_map; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* prim_func = base_func.as()) { + device_info_map[gvar.get()] = DeviceInfoCollector::Collect(prim_func); + } + } + return DeviceKernelMutator(std::move(device_info_map)); + }(); + + { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + + { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + + return mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") + .set_body_typed(LowerDeviceKernelLaunch); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 3ae114f31877..89f17b3ca84c 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -44,22 +44,21 @@ namespace tir { /*! * \brief Visitor class to collect device-side program information. 
*/ -class DeviceInfoCollector : public StmtVisitor { +class LaunchParamsAnnotator : public StmtVisitor { public: - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; + static PrimFunc Apply(PrimFunc func) { + LaunchParamsAnnotator collector; + collector(func->body); + return WithAttr(std::move(func), tir::attr::kKernelLaunchParams, collector.GetLaunchParams()); + } Array GetLaunchParams() const { - Array output; - for (const auto& axis : thread_axis_) { - output.push_back(axis->thread_tag); - } - if (use_dyn_shmem_) { - output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); + Array launch_params = threads_; + + if (uses_dyn_shmem_) { + launch_params.push_back(tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); } - return output; + return launch_params; } private: @@ -71,216 +70,125 @@ class DeviceInfoCollector : public StmtVisitor { // use the first appearance as def. if (!defined_thread.count(iv.get())) { defined_thread.insert(iv.get()); - thread_axis_.push_back(iv); - thread_extent_.push_back(op->value); + threads_.push_back(iv->thread_tag); } - - this->VisitExpr(op->value); - this->VisitStmt(op->body); - } else { - StmtVisitor::VisitStmt_(op); } + + StmtVisitor::VisitStmt_(op); } void VisitStmt_(const AllocateNode* op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK(!uses_dyn_shmem_) << "Only one dynamic shared memory allocation is allowed."; ICHECK_GT(op->extents.size(), 0); - dyn_shmem_size_ = op->extents[0]; - for (size_t i = 1; i < op->extents.size(); ++i) { - dyn_shmem_size_ *= op->extents[i]; - } - dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); - use_dyn_shmem_ = true; + + uses_dyn_shmem_ = true; } StmtVisitor::VisitStmt_(op); } + Array 
threads_; + bool uses_dyn_shmem_{false}; // recording what thread axis have been visited. std::unordered_set defined_thread; }; -/*! - * \brief Mutator class to remove unrefenced let stmt/expressions. - * \param use_count The pre-computed variable to use count map. - */ -class UnreferencedLetRemover : public StmtExprMutator { - public: - explicit UnreferencedLetRemover(const std::unordered_map& use_count) - : use_count_(use_count) {} - - private: - Stmt VisitStmt_(const LetStmtNode* op) final { - Stmt body = this->VisitStmt(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { - return body; - } else { - PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return LetStmt(op->var, value, body); - } - } - } - - PrimExpr VisitExpr_(const LetNode* op) final { - PrimExpr body = this->VisitExpr(op->body); - PrimExpr value = this->VisitExpr(op->value); - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { - return body; - } else { - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return Let(op->var, value, body); - } - } - } - - // pre-computed variable to use count map. 
- const std::unordered_map& use_count_; -}; - class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) - : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {} - - Stmt VisitStmt_(const AllocateNode* op) final { - handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); - return StmtMutator::VisitStmt_(op); - } + explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix) + : device_mod_(device_mod), name_prefix_(name_prefix) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || - op->attr_key == attr::device_scope) { - return SplitDeviceFunc(GetRef(op)); + if (op->attr_key == tvm::attr::kTarget) { + auto device_target = op->node.as().value().WithoutHost(); + return SplitDeviceFunc(op->body, device_target); } return StmtMutator::VisitStmt_(op); } private: - Stmt SplitDeviceFunc(Stmt body) { - std::ostringstream os; - os << name_prefix_ << "_kernel" << device_func_counter_++; - std::string kernel_symbol = os.str(); - // isolate the device function. - VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); - use_def(body); - DeviceInfoCollector dev_info; - dev_info(body); - UnreferencedLetRemover let_remover(use_def.use_count_); - body = let_remover(std::move(body)); - - Array params; - Array arguments; - Map remap_vars; - - // Strictly order the arguments: Var pointers, positional arguments. - for (Var var : use_def.undefined_) { - if (var.dtype().is_handle()) { - // Create a new version of v. 
- auto it = handle_data_type_.find(var.get()); - if (it != handle_data_type_.end()) { - String storage_scope; - if (auto* ptr_type = var->type_annotation.as()) { - storage_scope = ptr_type->storage_scope; - } - tir::Var new_var(var->name_hint, - PointerType(PrimType((*it).second->dtype), storage_scope)); - params.push_back(new_var); - remap_vars.Set(var, new_var); - } else { - params.push_back(var); - } - arguments.push_back(var); - } - } - // positional arguments - for (Var var : use_def.undefined_) { - if (!var.dtype().is_handle()) { - params.push_back(var); - arguments.push_back(var); - } - } - GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_); - GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); - - PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams, - dev_info.GetLaunchParams()); - - device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, - Integer(CallingConv::kDeviceKernelLaunch)); - device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, - runtime::String(kernel_symbol_global->name_hint)); - device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); - device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); - device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); + Stmt SplitDeviceFunc(Stmt body, Target device_target) { + Array params = [&]() { + VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); + use_def(body); + + // Sort first by variable type, then by variable name + std::vector params{use_def.undefined_.begin(), use_def.undefined_.end()}; + std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) { + auto sort_key = [](const Var& var) { + return std::tuple{ + !var->dtype.is_handle(), + var->name_hint, + }; + }; + return sort_key(a) < sort_key(b); + }); + return 
params; + }(); + + GlobalVar kernel_symbol_global = [&]() { + std::stringstream name; + name << name_prefix_ << "_kernel"; + GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_); + return global_var_supply->FreshGlobal(name.str(), false); + }(); + + PrimFunc device_func(params, body); + device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, + {tir::attr::kNoAlias, Bool(true)}}); + device_func = LaunchParamsAnnotator::Apply(std::move(device_func)); (*device_mod_)->Add(kernel_symbol_global, device_func); + Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); - // generate calls to the device function - Array call_args; - call_args.push_back(StringImm(kernel_symbol_global->name_hint)); - for (PrimExpr arg : arguments) { - call_args.push_back(arg); - } - for (PrimExpr ext : dev_info.thread_extent_) { - call_args.push_back(ext); - } - if (dev_info.use_dyn_shmem_) { - call_args.push_back(dev_info.dyn_shmem_size_); - } - return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); + return Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); } // target ir module IRModule* device_mod_; - // Device target - Target device_target_; // function name hint std::string name_prefix_; - // Number of device functions. 
- int device_func_counter_{0}; - std::unordered_map handle_data_type_; }; -PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod, const GlobalVar& gvar) { - auto target = func->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; +PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) { + auto opt_target = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target) << "SplitHostDevice: Require the target attribute"; + Target target = opt_target.value(); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); - HostDeviceSplitter splitter(device_mod, target.value(), name_prefix); + HostDeviceSplitter splitter(device_mod, name_prefix); + + auto body = splitter(func->body); - auto* n = func.CopyOnWrite(); - n->body = splitter(std::move(n->body)); - // set the host target to None. - func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); - return std::move(func); + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + func = WithAttr(std::move(func), tvm::tir::attr::kIsHostFunc, Bool(true)); + } + + return func; } namespace transform { Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { - IRModuleNode* mod_ptr = mod.CopyOnWrite(); - auto* func_dict = mod_ptr->functions.CopyOnWrite(); IRModule device_mod = IRModule(Map({})); - - for (auto& kv : *func_dict) { - auto gvar = Downcast(kv.first); - auto& base_func = kv.second; - if (base_func->IsInstance()) { - PrimFunc func = Downcast(std::move(base_func)); - ICHECK(device_mod.defined()) << "The device module must be defined."; - base_func = SplitHostDevice(std::move(func), &device_mod, gvar); + IRModule updates = IRModule(Map({})); + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + PrimFunc func = opt.value(); + func = SplitHostDevice(std::move(func), &device_mod, 
gvar); + if (!func.same_as(base_func)) { + updates->Add(gvar, func); + } } } + + mod->Update(updates); mod->Update(device_mod); return ConvertSSA()(mod); }; diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 680f23e07a17..2b4a3f51cb3b 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -35,17 +35,26 @@ def test_split_host_device_func_attr(): s[A1].compute_at(s[A2], xo) s[A1].set_scope("shared") - mod = tvm.lower(s, [A, A2], name="f") + mod = tvm.lower(s, [A, A2]) - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") mod = tvm.tir.transform.Apply( lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) )(mod) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] - assert fdevice.attrs["global_symbol"] == "test_kernel0" + mod = tvm.ir.transform.Sequential( + [ + tvm.tir.transform.AnnotateDeviceRegions(), + tvm.tir.transform.SplitHostDevice(), + tvm.tir.transform.LowerDeviceKernelLaunch(), + ] + )(mod) + + fdevice = mod["test_kernel"] + + assert fdevice.attrs["global_symbol"] == "test_kernel" assert fdevice.attrs["calling_conv"].value == 2 - assert fdevice.attrs["target"] == cuda_target + assert str(fdevice.attrs["target"]) == str(tvm.target.Target("cuda")) assert fdevice.attrs["tir.is_global_func"].value @@ -60,15 +69,21 @@ def test_ssa_across_entire_module(): class before: @T.prim_func def main(): - T.func_attr({"global_symbol": "main", "target": T.target("cuda")}) + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) for i in range(16): T.attr(0, "device_scope", 0) for j in range(16): T.evaluate(i) - after = tvm.tir.transform.SplitHostDevice()(before) + after = tvm.ir.transform.Sequential( + [ + tvm.tir.transform.AnnotateDeviceRegions(), + tvm.tir.transform.SplitHostDevice(), + 
tvm.tir.transform.LowerDeviceKernelLaunch(), + ] + )(before) loop_var = after["main"].body.loop_var - param_var = after["main_kernel0"].params[0] + param_var = after["main_kernel"].params[0] assert not loop_var.same_as(param_var) From f973841b2a4ce8137b40d3f1eb59ba4453ea10b5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 11:29:05 -0500 Subject: [PATCH 05/16] Add unit tests specifically for SplitHostDevice behavior --- .../test_tir_transform_split_host_device.py | 78 ++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 2b4a3f51cb3b..3003cdc2516d 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -88,5 +88,81 @@ def main(): assert not loop_var.same_as(param_var) +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.SplitHostDevice() + + +class TestSplitHostDevice(BaseCompare): + """SplitHostDevice divides a function at the "target" attribute""" + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("llvm")}) + T.attr(T.target("cuda"), "target", 0) + T.evaluate(n) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("llvm"), "tir.is_host_func": True}) + mod.main_kernel(n) + + @T.prim_func + def main_kernel(n: T.int32): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": [], + "tir.noalias": T.bool(True), + } + ) + T.evaluate(n) + + return mod + + +class TestSplitHostDeviceWithHost(BaseCompare): + """Host annotations are preserved""" + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("cuda", host="llvm")}) + 
T.attr(T.target("cuda"), "target", 0) + T.evaluate(n) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + mod.main_kernel(n) + + @T.prim_func + def main_kernel(n: T.int32): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": [], + "tir.noalias": T.bool(True), + } + ) + T.evaluate(n) + + return mod + + if __name__ == "__main__": - test_split_host_device_func_attr() + tvm.testing.main() From 95ae88ea0c4c06dba01eb1f4452536e77aa1bd5b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 11:44:10 -0500 Subject: [PATCH 06/16] Added unit test specifically for AnnotateDeviceRegions --- ...t_tir_transform_annotate_device_regions.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/python/unittest/test_tir_transform_annotate_device_regions.py diff --git a/tests/python/unittest/test_tir_transform_annotate_device_regions.py b/tests/python/unittest/test_tir_transform_annotate_device_regions.py new file mode 100644 index 000000000000..efa43027e9c6 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_annotate_device_regions.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T, ir as I + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.AnnotateDeviceRegions() + + +class TestAnnotateThreadExtent(BaseCompare): + """Annotation inserted at the "thread_extent" attribute""" + + def before(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + def expected(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + +class TestAnnotateDeviceScope(BaseCompare): + """Annotation inserted at the "device_scope" attribute""" + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(0, "device_scope", 0) + A[0] = 0.0 + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + T.attr(0, "device_scope", 0) + A[0] = 0.0 + + +if __name__ == "__main__": + tvm.testing.main() From f3c87a55d40e81cd5c2a91e27093784aa9d6c18f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 12:12:55 -0500 Subject: [PATCH 07/16] Added unit tests for LowerDeviceKernelLaunch --- ...test_tir_transform_device_kernel_launch.py | 230 ++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 tests/python/unittest/test_tir_transform_device_kernel_launch.py diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py new file mode 100644 index 000000000000..e0ce8b856c51 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under 
one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T, ir as I + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.LowerDeviceKernelLaunch() + + +class TestLowerDeviceKernelLaunch(BaseCompare): + """Kernel launch parameters are added at the call site + + The "tir.kernel_launch_params" determines which parameters belong + to the runtime, and which belong to the device-side PrimFunc. + Parameters that are required prior to launching a kernel (e.g. the + number of Cuda threads to use) are stored in the + `"tir.kernel_launch_params"` attribute, and are used by the + runtime in order to launch the generated kernel. 
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": [], + "global_symbol": "kernel", + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.call_packed("kernel", A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": [], + "calling_conv": 2, + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + +class TestInternalKernelLaunch(BaseCompare): + """Like TestLowerDeviceKernelLaunch, but the kernel has no global_symbol + + Because the host and kernel will be handled by different code + generators, the device-side kernel must be externally exposed for + use by the host-side wrapper, even if the host-side wrapper does + not directly expose the kernel. Therefore, a "global_symbol" + attribute must be added for the kernel if not already present. 
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": [], + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.call_packed("kernel", A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": [], + "calling_conv": 2, + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + +class TestCollectLaunchParameter(BaseCompare): + """Kernel launch parameters are added at the call site + + The "tir.kernel_launch_params" determines which parameters belong + to the runtime, and which belong to the device-side PrimFunc. + Parameters that are required prior to launching a kernel (e.g. the + number of Cuda threads to use) are stored in the + `"tir.kernel_launch_params"` attribute, and are used by the + runtime in order to launch the generated kernel. 
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": ["threadIdx.x"], + "global_symbol": "kernel", + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.call_packed("kernel", A.data, 16) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "tir.kernel_launch_params": ["threadIdx.x"], + "calling_conv": 2, + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + return mod + + +class TestErrorWhenMissingLaunchParams(BaseCompare): + """Kernel must have tir::attr::kKernelLaunchParams + + The PrimFunc attribute `tir::attr::kKernelLaunchParams` + ("tir.kernel_launch_params") is used to determine the order in + which kernel parameters are provided by the runtime. 
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + A = T.decl_buffer(16, dtype="float32", data=A_data) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + return mod + + expected = tvm.TVMError + + +if __name__ == "__main__": + tvm.testing.main() From a85607682743f4a81ff456bf617aec084e43ba24 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 12:48:52 -0500 Subject: [PATCH 08/16] Minor cleanup, moved all kernel launch collection into one spot Previously, the SplitHostDevice pass added the `tir::attr::kKernelLaunchParams` attribute, and the LowerDeviceKernelLaunch pass filled in the values for it. This cleanup makes the kernel launch params be the sole responsibility of LowerDeviceKernelLaunch. --- .../transforms/lower_device_kernel_launch.cc | 112 +++++++++++------- src/tir/transforms/split_host_device.cc | 54 --------- ...test_tir_transform_device_kernel_launch.py | 61 ++-------- .../test_tir_transform_split_host_device.py | 2 - 4 files changed, 84 insertions(+), 145 deletions(-) diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 3a089ee57e94..faf13825271f 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -36,27 +36,27 @@ namespace tvm { namespace tir { namespace { -struct DeviceInfo { +struct KernelInfo { + // The device on which the PrimFunc runs Target target; + + // The externally visible symbol which may refer to the PrimFunc + // when launching a device kernel. + String global_symbol; + + // The parameters accepted by the PrimFunc. Used to rewrite + // `launch_args` to be in terms of the calling scope. 
Array params; - Optional> launch_params; - Map thread_extent; - Optional dyn_shmem_size{NullOpt}; - PrimExpr GetArgument(const String& launch_param) const { - if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { - CHECK(dyn_shmem_size.defined()) - << "Compute kernel requires launch parameter \"" << launch_param - << "\", but PrimFunc did not contain Allocate node with shared dynamic scope."; - return dyn_shmem_size.value(); - } + // The launch parameters that should annotate the PrimFunc, if the + // kernel is ever called from the host. + Array launch_params; - auto extent = thread_extent.Get(launch_param); - CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param - << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent - << "\" defining this thread extent"; - return extent.value(); - } + // Additional arguments which must be provided to the host-side + // PackedFunc. These may be in terms of the function's parameters + // (e.g. a function that computes the average of `N` elements, and + // which must be launched with `N` CUDA threads). + Array launch_args; }; /*! 
@@ -64,7 +64,7 @@ struct DeviceInfo { */ class DeviceInfoCollector : public StmtVisitor { public: - static DeviceInfo Collect(const PrimFuncNode* func) { + static KernelInfo Collect(const GlobalVar& gvar, const PrimFuncNode* func) { DeviceInfoCollector collector; collector.info_.target = [&]() -> Target { auto target_attr = func->GetAttr(tvm::attr::kTarget).value(); @@ -77,12 +77,41 @@ class DeviceInfoCollector : public StmtVisitor { } }(); collector.info_.params = func->params; - collector.info_.launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams); + collector(func->body); + + // The dynamic shared memory is required to be the last of the + // kernel launch parameters + if (collector.dyn_shmem_size) { + collector.info_.launch_params.push_back( + tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); + } + + collector.info_.global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + + collector.info_.launch_args = collector.info_.launch_params.Map( + [&](const auto& param) { return collector.GetArgument(param); }); + return collector.info_; } private: + PrimExpr GetArgument(const String& launch_param) const { + if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { + CHECK(dyn_shmem_size.defined()) + << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc did not contain Allocate node with shared dynamic scope."; + return dyn_shmem_size.value(); + } + + auto extent = thread_extent.Get(launch_param); + CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent + << "\" defining this thread extent"; + return extent.value(); + } + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -91,7 +120,8 @@ class DeviceInfoCollector : public StmtVisitor { // use the first appearance as def. 
if (!defined_thread.count(iv.get())) { defined_thread.insert(iv.get()); - info_.thread_extent.Set(iv->thread_tag, op->value); + info_.launch_params.push_back(iv->thread_tag); + thread_extent.Set(iv->thread_tag, op->value); } } @@ -101,8 +131,7 @@ class DeviceInfoCollector : public StmtVisitor { void VisitStmt_(const AllocateNode* op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK(!info_.dyn_shmem_size.defined()) - << "Only one dynamic shared memory allocation is allowed."; + ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory allocation is allowed."; ICHECK_GT(op->extents.size(), 0); PrimExpr dyn_size = Integer(1); @@ -111,14 +140,19 @@ class DeviceInfoCollector : public StmtVisitor { } dyn_size *= op->dtype.bytes(); - info_.dyn_shmem_size = dyn_size; + dyn_shmem_size = dyn_size; } StmtVisitor::VisitStmt_(op); } - DeviceInfo info_; + // The collected results + KernelInfo info_; // recording what thread axis have been visited. 
std::unordered_set defined_thread; + // The extent of each thread + Map thread_extent; + // The amount of dynamic shared memory used + Optional dyn_shmem_size{NullOpt}; }; } // namespace @@ -126,7 +160,7 @@ class DeviceKernelMutator : public StmtExprMutator { public: using Parent = StmtExprMutator; - explicit DeviceKernelMutator(std::unordered_map device_info_map) + explicit DeviceKernelMutator(std::unordered_map device_info_map) : device_info_map_(std::move(device_info_map)) {} PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) { @@ -146,15 +180,13 @@ class DeviceKernelMutator : public StmtExprMutator { PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { if (device_kernel_launch_.count(gvar.get())) { - Map new_attrs; - new_attrs.Set(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)); - new_attrs.Set(tvm::tir::attr::kIsGlobalFunc, Bool(true)); - - if (!func->GetAttr(tvm::attr::kGlobalSymbol)) { - new_attrs.Set(tvm::attr::kGlobalSymbol, gvar->name_hint); - } + const auto& info = device_info_map_.at(gvar.get()); - func = WithAttrs(std::move(func), new_attrs); + func = WithAttrs(std::move(func), + {{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)}, + {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, + {tvm::attr::kGlobalSymbol, info.global_symbol}, + {tvm::tir::attr::kIsGlobalFunc, Bool(true)}}); } return func; @@ -171,7 +203,7 @@ class DeviceKernelMutator : public StmtExprMutator { ICHECK(it != device_info_map_.end()) << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " << gvar->name_hint << " did not appear within the IRModule"; - const DeviceInfo& dev_info = it->second; + const KernelInfo& dev_info = it->second; auto caller_device_type = current_target_.value()->GetTargetDeviceType(); auto callee_device_type = dev_info.target->GetTargetDeviceType(); @@ -204,12 +236,12 @@ class DeviceKernelMutator : public StmtExprMutator { 
device_kernel_launch_.insert(gvar); Array call_args; - call_args.push_back(StringImm(gvar->name_hint)); + call_args.push_back(StringImm(dev_info.global_symbol)); for (PrimExpr arg : node->args) { call_args.push_back(arg); } - for (const auto& launch_param : dev_info.launch_params.value()) { - call_args.push_back(Substitute(dev_info.GetArgument(launch_param), param_map)); + for (const auto& launch_arg : dev_info.launch_args) { + call_args.push_back(Substitute(launch_arg, param_map)); } auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype; @@ -218,7 +250,7 @@ class DeviceKernelMutator : public StmtExprMutator { } Optional current_target_; - std::unordered_map device_info_map_; + std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; }; @@ -227,10 +259,10 @@ namespace transform { Pass LowerDeviceKernelLaunch() { auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { auto mutator = [&mod]() { - std::unordered_map device_info_map; + std::unordered_map device_info_map; for (const auto& [gvar, base_func] : mod->functions) { if (auto* prim_func = base_func.as()) { - device_info_map[gvar.get()] = DeviceInfoCollector::Collect(prim_func); + device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func); } } return DeviceKernelMutator(std::move(device_info_map)); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 89f17b3ca84c..996a33e936ef 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -41,59 +41,6 @@ namespace tvm { namespace tir { -/*! - * \brief Visitor class to collect device-side program information. 
- */ -class LaunchParamsAnnotator : public StmtVisitor { - public: - static PrimFunc Apply(PrimFunc func) { - LaunchParamsAnnotator collector; - collector(func->body); - return WithAttr(std::move(func), tir::attr::kKernelLaunchParams, collector.GetLaunchParams()); - } - - Array GetLaunchParams() const { - Array launch_params = threads_; - - if (uses_dyn_shmem_) { - launch_params.push_back(tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); - } - return launch_params; - } - - private: - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); - // thread_extent can appear multiple times - // use the first appearance as def. - if (!defined_thread.count(iv.get())) { - defined_thread.insert(iv.get()); - threads_.push_back(iv->thread_tag); - } - } - - StmtVisitor::VisitStmt_(op); - } - - void VisitStmt_(const AllocateNode* op) final { - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK(!uses_dyn_shmem_) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); - - uses_dyn_shmem_ = true; - } - StmtVisitor::VisitStmt_(op); - } - - Array threads_; - bool uses_dyn_shmem_{false}; - // recording what thread axis have been visited. 
- std::unordered_set defined_thread; -}; - class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix) @@ -137,7 +84,6 @@ class HostDeviceSplitter : public StmtMutator { PrimFunc device_func(params, body); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}}); - device_func = LaunchParamsAnnotator::Apply(std::move(device_func)); (*device_mod_)->Add(kernel_symbol_global, device_func); Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py index e0ce8b856c51..5ad7c975b84b 100644 --- a/tests/python/unittest/test_tir_transform_device_kernel_launch.py +++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py @@ -45,13 +45,7 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def kernel(A_data: T.handle("float32")): - T.func_attr( - { - "target": T.target("cuda"), - "tir.kernel_launch_params": [], - "global_symbol": "kernel", - } - ) + T.func_attr({"target": T.target("cuda")}) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 @@ -70,8 +64,8 @@ def kernel(A_data: T.handle("float32")): T.func_attr( { "target": T.target("cuda"), - "tir.kernel_launch_params": [], "calling_conv": 2, + "tir.kernel_launch_params": [], "global_symbol": "kernel", "tir.is_global_func": True, } @@ -82,14 +76,17 @@ def kernel(A_data: T.handle("float32")): return mod -class TestInternalKernelLaunch(BaseCompare): - """Like TestLowerDeviceKernelLaunch, but the kernel has no global_symbol +class TestExternallyVisibleKernelLaunch(BaseCompare): + """Like TestLowerDeviceKernelLaunch, with pre-defined global_symbol Because the host and kernel will be handled by different code generators, the device-side kernel must be externally exposed for use by the host-side wrapper, even if the 
host-side wrapper does not directly expose the kernel. Therefore, a "global_symbol" attribute must be added for the kernel if not already present. + + If the kernel already has a specific name, that name should be + preserved. """ def before(self): @@ -102,12 +99,7 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def kernel(A_data: T.handle("float32")): - T.func_attr( - { - "target": T.target("cuda"), - "tir.kernel_launch_params": [], - } - ) + T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"}) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 @@ -119,16 +111,16 @@ class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) - T.call_packed("kernel", A.data) + T.call_packed("kernel_by_another_name", A.data) @T.prim_func def kernel(A_data: T.handle("float32")): T.func_attr( { "target": T.target("cuda"), - "tir.kernel_launch_params": [], "calling_conv": 2, - "global_symbol": "kernel", + "tir.kernel_launch_params": [], + "global_symbol": "kernel_by_another_name", "tir.is_global_func": True, } ) @@ -162,7 +154,6 @@ def kernel(A_data: T.handle("float32")): T.func_attr( { "target": T.target("cuda"), - "tir.kernel_launch_params": ["threadIdx.x"], "global_symbol": "kernel", } ) @@ -185,8 +176,8 @@ def kernel(A_data: T.handle("float32")): T.func_attr( { "target": T.target("cuda"), - "tir.kernel_launch_params": ["threadIdx.x"], "calling_conv": 2, + "tir.kernel_launch_params": ["threadIdx.x"], "global_symbol": "kernel", "tir.is_global_func": True, } @@ -198,33 +189,5 @@ def kernel(A_data: T.handle("float32")): return mod -class TestErrorWhenMissingLaunchParams(BaseCompare): - """Kernel must have tir::attr::kKernelLaunchParams - - The PrimFunc attribute `tir::attr::kKernelLaunchParams` - ("tir.kernel_launch_params") is used to determine the order in - which kernel parameters are provided by the runtime. 
- """ - - def before(self): - @I.ir_module - class mod: - @T.prim_func - def main(A: T.Buffer(16, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) - mod.kernel(A.data) - - @T.prim_func - def kernel(A_data: T.handle("float32")): - T.func_attr({"target": T.target("cuda")}) - A = T.decl_buffer(16, dtype="float32", data=A_data) - i = T.launch_thread("threadIdx.x", 16) - A[i] = 0.0 - - return mod - - expected = tvm.TVMError - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 3003cdc2516d..a3adfb47dc21 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -119,7 +119,6 @@ def main_kernel(n: T.int32): T.func_attr( { "target": T.target("cuda"), - "tir.kernel_launch_params": [], "tir.noalias": T.bool(True), } ) @@ -155,7 +154,6 @@ def main_kernel(n: T.int32): T.func_attr( { "target": T.target("cuda"), - "tir.kernel_launch_params": [], "tir.noalias": T.bool(True), } ) From 455c7a87918600ed824b3fa8be79c287e764b0a3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 23 May 2023 08:25:54 -0500 Subject: [PATCH 09/16] Updated unit tests for LowerWarpMemory --- .../test_tir_transform_lower_warp_memory.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index d4abc26bb204..c7e90d4e7dc9 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -22,6 +22,16 @@ from tvm.contrib.nvcc import have_fp16 +def _run_passes(mod): + cuda_target = tvm.target.Target("cuda", host="llvm") + assert cuda_target.thread_warp_size == 32 + mod = 
tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) + mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) + mod = tvm.tir.transform.SplitHostDevice()(mod) + mod = tvm.tir.transform.LowerWarpMemory()(mod) + return mod + + @tvm.testing.requires_cuda def test_lower_warp_memory_local_scope(): m = 128 @@ -39,16 +49,12 @@ def test_lower_warp_memory_local_scope(): xo, xi = s[AA].split(s[AA].op.axis[0], 32) s[AA].bind(xi, tx) - cuda_target = tvm.target.Target("cuda") - assert cuda_target.thread_warp_size == 32 # lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): mod = tvm.lower(s, [A, B], name="f") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + mod = _run_passes(mod) + fdevice = mod["f_kernel"] allocate = fdevice.body.body assert allocate.buffer_var.type_annotation.storage_scope == "local" assert fdevice.body.body.extents[0].value == 2 @@ -103,7 +109,7 @@ def check_cuda(dtype): A = te.placeholder((m,), name="A", dtype=dtype) B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name="B") - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") assert cuda_target.thread_warp_size == 32 with cuda_target: s = te.create_schedule(B.op) @@ -168,7 +174,7 @@ def check_cuda(dtype): name="B", ) - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") assert cuda_target.thread_warp_size == 2 * m with cuda_target: s = te.create_schedule(B.op) @@ -214,7 +220,7 @@ def check_cuda(dtype): B = te.placeholder((m,), name="B", dtype=dtype) C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name="C") - cuda_target = tvm.target.Target("cuda") + 
cuda_target = tvm.target.Target("cuda", host="llvm") assert m <= cuda_target.thread_warp_size with cuda_target: s = te.create_schedule(C.op) @@ -310,15 +316,12 @@ def test_lower_warp_memory_same_thread(): xo, xi = s[BB].split(s[BB].op.axis[0], factor=32) s[BB].bind(xi, tx) - cuda_target = tvm.target.Target("cuda") - assert cuda_target.thread_warp_size == 32 # lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): mod = tvm.lower(s, [A, B], name="f") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + + mod = _run_passes(mod) + fdevice = mod["f_kernel"] assert "tvm_warp_shuffle" not in fdevice.script() @@ -338,13 +341,11 @@ def test_lower_warp_memory_divide_by_factor(): stmt = ib.get() func = tvm.tir.PrimFunc([], stmt) func = func.with_attr("from_legacy_te_schedule", True) - cuda_target = tvm.target.Target("cuda") # lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): mod = tvm.lower(func, name="f") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) with pytest.raises(tvm.error.TVMError, match="Divide by zero") as cm: - tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + _run_passes(mod) if __name__ == "__main__": From fcdfaf2681daff016e7f8df1ec94a9ac76a29a26 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 23 May 2023 08:26:31 -0500 Subject: [PATCH 10/16] Updated unit tests for ThreadSync --- tests/python/unittest/test_tir_transform_thread_sync.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py 
b/tests/python/unittest/test_tir_transform_thread_sync.py index eb578a8817b5..57ea223cf984 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -24,12 +24,13 @@ def run_passes(func: tvm.tir.PrimFunc): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") mod = tvm.tir.transform.Apply( lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) )(mod) + mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) mod = tvm.tir.transform.SplitHostDevice()(mod) return tvm.tir.transform.ThreadSync("shared")(mod) @@ -55,7 +56,7 @@ def test_thread_storage_sync(): func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) mod = run_passes(func) - f = mod["test_kernel0"] + f = mod["test_kernel"] body_list = tvm.tir.stmt_list(f.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) From 36b5657e2ef084db07685408a34944e8fd23d1da Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 23 May 2023 08:27:15 -0500 Subject: [PATCH 11/16] Updated unit test for inject ptx async copy --- .../python/unittest/test_tir_transform_inject_ptx_async_copy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 168f8c879bcd..a48a8ea236ec 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -201,7 +201,7 @@ def test_inject_async_copy_shared_dyn(): #define int64_t long long #define uint64_t unsigned long long #endif -extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { +extern "C" __global__ void 
__launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64]; A_shared[((int)threadIdx.x)] = 0.000000e+00f; From 65ac46a66cd3f350db4dae04d0e8c93431266ed4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 09:36:51 -0500 Subject: [PATCH 12/16] [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs https://github.com/apache/tvm/pull/14913 and https://github.com/apache/tvm/pull/14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. 
--- src/tir/transforms/make_packed_api.cc | 3 +++ src/tir/transforms/make_unpacked_api.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. 
From 96dc763580d0ab1d3b89c5e7686876c85aedfdc5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 11:04:43 -0500 Subject: [PATCH 13/16] Maintain "tir.is_global_func" attr in device-side entry point --- src/tir/transforms/split_host_device.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 996a33e936ef..de313f62eab9 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -83,7 +83,8 @@ class HostDeviceSplitter : public StmtMutator { PrimFunc device_func(params, body); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, - {tir::attr::kNoAlias, Bool(true)}}); + {tir::attr::kNoAlias, Bool(true)}, + {tir::attr::kIsGlobalFunc, Bool(true)}}); (*device_mod_)->Add(kernel_symbol_global, device_func); Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); From 2c7c605c01f817d9025fed386e7ff30dfc3db4c5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 14:52:42 -0500 Subject: [PATCH 14/16] SplitHostDevice, update the host-side target to be the host --- src/tir/transforms/split_host_device.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index de313f62eab9..9270b356ba22 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -112,7 +112,8 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g if (!body.same_as(func->body)) { func.CopyOnWrite()->body = body; - func = WithAttr(std::move(func), tvm::tir::attr::kIsHostFunc, Bool(true)); + auto target_host = target->GetHost().value_or(Target("llvm")); + func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); } return func; From eb87360388705dd82893cc5afa1b4b7f5f01018d Mon Sep 17 00:00:00 2001 From: Eric 
Lunderberg Date: Wed, 24 May 2023 09:59:29 -0500 Subject: [PATCH 15/16] [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on https://github.com/apache/tvm/pull/14020, the `kIsHostFunc` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`. --- .../transforms/lower_device_kernel_launch.cc | 17 ++++------------- .../test_tir_transform_device_kernel_launch.py | 12 ++++++------ 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index faf13825271f..5ffbf0d7a7fd 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -64,18 +64,9 @@ struct KernelInfo { */ class DeviceInfoCollector : public StmtVisitor { public: - static KernelInfo Collect(const GlobalVar& gvar, const PrimFuncNode* func) { + static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) { DeviceInfoCollector collector; - collector.info_.target = [&]() -> Target { - auto target_attr = func->GetAttr(tvm::attr::kTarget).value(); - bool is_host_func = - func->GetAttr(tvm::tir::attr::kIsHostFunc).value_or(Bool(false))->value; - if (is_host_func) { - return target_attr->GetHost().value(); - } else { - return target_attr.WithoutHost(); - } - }(); + collector.info_.target = func->GetAttr(tvm::attr::kTarget).value().WithoutHost(); collector.info_.params = func->params; collector(func->body); @@ -261,8 +252,8 @@ Pass LowerDeviceKernelLaunch() { auto mutator = [&mod]() { std::unordered_map device_info_map; for (const auto& [gvar, base_func] : mod->functions) { - if (auto* prim_func = base_func.as()) { - device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func); + if (auto prim_func = base_func.as()) { + device_info_map[gvar.get()] =
DeviceInfoCollector::Collect(gvar, prim_func.value()); } } return DeviceKernelMutator(std::move(device_info_map)); diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py index 5ad7c975b84b..a0f77da3766b 100644 --- a/tests/python/unittest/test_tir_transform_device_kernel_launch.py +++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py @@ -40,7 +40,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm")}) mod.kernel(A.data) @T.prim_func @@ -56,7 +56,7 @@ def expected(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm")}) T.call_packed("kernel", A.data) @T.prim_func @@ -94,7 +94,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm")}) mod.kernel(A.data) @T.prim_func @@ -110,7 +110,7 @@ def expected(self): class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm")}) T.call_packed("kernel_by_another_name", A.data) @T.prim_func @@ -146,7 +146,7 @@ def before(self): class mod: @T.prim_func def main(A: T.Buffer(16, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm")}) mod.kernel(A.data) @T.prim_func @@ -168,7 +168,7 @@ def expected(self): class mod: @T.prim_func def main(A: T.Buffer(16, "float32")): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": 
T.target("llvm")}) T.call_packed("kernel", A.data, 16) @T.prim_func From 30d28213c782af09e5a6e532c29d78a1fc338382 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 12:35:30 -0500 Subject: [PATCH 16/16] Remove is_host_func from SplitHostDevice tests --- .../test_tir_transform_split_host_device.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index a3adfb47dc21..cf866ae005c8 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -100,7 +100,7 @@ def before(self): class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("llvm")}) + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) T.attr(T.target("cuda"), "target", 0) T.evaluate(n) @@ -111,7 +111,7 @@ def expected(self): class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm -opt-level=0")}) mod.main_kernel(n) @T.prim_func @@ -120,6 +120,7 @@ def main_kernel(n: T.int32): { "target": T.target("cuda"), "tir.noalias": T.bool(True), + "tir.is_global_func": True, } ) T.evaluate(n) @@ -127,15 +128,19 @@ def main_kernel(n: T.int32): return mod -class TestSplitHostDeviceWithHost(BaseCompare): - """Host annotations are preserved""" +class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare): + """Like TestSplitHostDevice, but no host specified in the host's target + + The `T.attr` specifying the device still requires splitting out + the kernel. 
+ """ def before(self): @I.ir_module class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("cuda", host="llvm")}) + T.func_attr({"target": T.target("llvm")}) T.attr(T.target("cuda"), "target", 0) T.evaluate(n) @@ -146,7 +151,7 @@ def expected(self): class mod: @T.prim_func def main(n: T.int32): - T.func_attr({"target": T.target("cuda", host="llvm"), "tir.is_host_func": True}) + T.func_attr({"target": T.target("llvm")}) mod.main_kernel(n) @T.prim_func @@ -155,6 +160,7 @@ def main_kernel(n: T.int32): { "target": T.target("cuda"), "tir.noalias": T.bool(True), + "tir.is_global_func": True, } ) T.evaluate(n)