diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index ec8e32526abb..5bac25faa5fb 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -94,7 +94,7 @@ TVM_DLL double EstimateTIRFlops(const IRModule& mod); /*! * \brief Find undefined vars in the statement. - * \param stmt The function to be checked. + * \param stmt The statement to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. */ @@ -107,6 +107,14 @@ TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); */ TVM_DLL Array UndefinedVars(const PrimExpr& expr); +/*! + * \brief Find undefined vars in the expression. + * \param expr The expression to be checked. + * \param defs The vars that is defined. + * \return Array of undefined vars. + */ +TVM_DLL Array UndefinedVars(const PrimExpr& expr, const Array& defs); + /*! * \brief Analyze the side effect * \param expr The expression to be checked. diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc new file mode 100644 index 000000000000..7ef8e532a396 --- /dev/null +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/*! + * \file var_use_def_analysis.cc + * \brief Classes and functions to analyze var definition and usage. + */ +#include "var_use_def_analysis.h" +namespace tvm { +namespace tir { + +VarUseDefAnalyzer::VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent) + : visit_thread_extent_(visit_thread_extent) { + for (const Var v : defined_vars) { + use_count_[v.get()] = 0; + } +} + +void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. + if (!use_count_.count(iv->var.get())) { + this->HandleDef(iv->var.get()); + } + + if (visit_thread_extent_) { + this->VisitExpr(op->value); + } + + this->VisitStmt(op->body); + } else { + StmtExprVisitor::VisitStmt_(op); + } +} + +void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { + this->HandleDef(op->var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) { + this->HandleDef(op->loop_var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) { + this->HandleDef(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) { + this->HandleDef(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { + // Weaker SSA condition + // A single var can be bound in multiple lets + // but they have to bind to the same value. 
+ // This is used to allow cases when we reuse a single let + // expression to construct a nested expr. + // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var.get()); + this->VisitExpr(op->value); + if (it != let_binding_.end()) { + ICHECK(deep_equal_(it->second->value, op->value)) + << "Let cannot bind the same var to two different values"; + } else { + this->HandleDef(op->var.get()); + let_binding_[op->var.get()] = op; + } + this->VisitExpr(op->body); +} + +void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { + this->HandleUse(op); + StmtExprVisitor::VisitExpr_(op); +} + +void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) { + for (const auto& iv : op->axis) { + this->HandleDef(iv->var.get()); + } + StmtExprVisitor::VisitExpr_(op); +} + +void VarUseDefAnalyzer::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); +} + +void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) { + this->HandleUse(buffer->data.get()); + auto visit_arr = [&](Array arr) { + for (const auto& element : arr) { + this->VisitExpr(element); + } + }; + + visit_arr(buffer->shape); + visit_arr(buffer->strides); +} + +void VarUseDefAnalyzer::HandleDef(const VarNode* v) { + ICHECK(!def_count_.count(v)) << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; + ICHECK(!use_count_.count(v)) << "variable " << v->name_hint + << " has been used before definition!"; + use_count_[v] = 0; + def_count_[v] = 1; +} + +void VarUseDefAnalyzer::HandleUse(const VarNode* v) { + auto it = use_count_.find(v); + if (it != use_count_.end()) { + if (it->second >= 0) { + ++it->second; + } + } else { + undefined_.push_back(GetRef(v)); + use_count_[v] = -1; + } +} + +Array UndefinedVars(const Stmt& stmt, const Array& args) { + VarUseDefAnalyzer m(args); + 
m(stmt); + return m.undefined_; +} + +Array UndefinedVars(const PrimExpr& expr) { + VarUseDefAnalyzer m({}); + m(expr); + return m.undefined_; +} + +Array UndefinedVars(const PrimExpr& expr, const Array& args) { + VarUseDefAnalyzer m(args); + m(expr); + return m.undefined_; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h new file mode 100644 index 000000000000..ad275011d90c --- /dev/null +++ b/src/tir/analysis/var_use_def_analysis.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/src/tir/analysis/var_use_def_analysis.h + * \brief Variable definition and usage analysis class. + */ +#ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ +#define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ + +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Visitor class to perform use/def analysis. + * \param defined_vars Variables that have been defined. + * \param visit_thread_extent Whether to visit thread extent expressions or not. 
+ * \sa UndefinedVars + */ +class VarUseDefAnalyzer : public StmtExprVisitor { + public: + explicit VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent = true); + // The fields are publicly readable to + // be accessible to the users. + bool visit_thread_extent_{true}; + Array undefined_; + + std::unordered_map use_count_; + std::unordered_map def_count_; + + private: + ExprDeepEqual deep_equal_; + std::unordered_map let_binding_; + void VisitStmt_(const AttrStmtNode* op) final; + + void VisitStmt_(const LetStmtNode* op) final; + + void VisitStmt_(const ForNode* op) final; + + void VisitStmt_(const AllocateNode* op) final; + + void VisitStmt_(const AllocateConstNode* op) final; + + void VisitStmt_(const StoreNode* op) final; + + void VisitStmt_(const BufferStoreNode* op) final; + + void VisitExpr_(const LetNode* op) final; + + void VisitExpr_(const VarNode* op) final; + + void VisitExpr_(const ReduceNode* op) final; + + void VisitExpr_(const LoadNode* op) final; + + void VisitExpr_(const BufferLoadNode* op) final; + + void HandleDef(const VarNode* v); + + void HandleUse(const VarNode* v); + + void VisitBuffer(Buffer buffer); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de7d38d7d57..4f411228d262 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -35,64 +35,43 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../analysis/var_use_def_analysis.h" #include "ir_utils.h" namespace tvm { namespace tir { -// use/def analysis, also delete unreferenced lets -class VarUseDefAnalysis : public StmtExprMutator { +/*! + * \brief Visitor class to collect device-side program information. 
+ */ +class DeviceInfoCollector : public StmtVisitor { public: - Stmt VisitStmt_(const AttrStmtNode* op) final { + Array thread_axis_; + Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; + + private: + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. - if (!use_count_.count(iv->var.get())) { - this->HandleDef(iv->var.get()); + if (!defined_thread.count(iv.get())) { + defined_thread.insert(iv.get()); thread_axis_.push_back(iv); thread_extent_.push_back(op->value); } - PrimExpr value = op->value; - if (visit_thread_extent_) { - value = this->VisitExpr(value); - } - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); - } - return AttrStmt(op->node, op->attr_key, value, body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - Stmt VisitStmt_(const LetStmtNode* op) final { - this->HandleDef(op->var.get()); - Stmt body = this->VisitStmt(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { - return body; + this->VisitExpr(op->value); + this->VisitStmt(op->body); } else { - PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return LetStmt(op->var, value, body); - } + StmtVisitor::VisitStmt_(op); } } - Stmt VisitStmt_(const ForNode* op) final { - this->HandleDef(op->loop_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateNode* op) final { - this->HandleDef(op->buffer_var.get()); + void VisitStmt_(const AllocateNode* op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank 
== runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; @@ -104,44 +83,42 @@ class VarUseDefAnalysis : public StmtExprMutator { dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); use_dyn_shmem_ = true; } - return StmtExprMutator::VisitStmt_(op); + StmtVisitor::VisitStmt_(op); } - Stmt VisitStmt_(const AllocateConstNode* op) final { - this->HandleDef(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } + // records which thread axes have been visited. + std::unordered_set defined_thread; +}; - Stmt VisitStmt_(const StoreNode* op) final { - LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; - } +/*! + * \brief Mutator class to remove unreferenced let stmt/expressions. + * \param use_count The pre-computed variable to use count map. + */ +class UnreferencedLetRemover : public StmtExprMutator { + public: + explicit UnreferencedLetRemover(const std::unordered_map& use_count) + : use_count_(use_count) {} - Stmt VisitStmt_(const BufferStoreNode* op) final { - VisitBuffer(op->buffer); - return StmtExprMutator::VisitStmt_(op); + private: + Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt body = this->VisitStmt(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { + return body; + } else { + PrimExpr value = this->VisitExpr(op->value); + if (body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } } PrimExpr VisitExpr_(const LetNode* op) final { - // Weaker SSA condition - // A single var can be binded in multiple lets - // but they have to bind to the same value. 
- // (let x = 1 in x + 1) * (let x = 1 in x + 1) - auto it = let_binding_.find(op->var); - PrimExpr value = this->VisitExpr(op->value); - if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, value)) - << "Let cannot bind the same var to two different values"; - return GetRef(it->second); - } else { - this->HandleDef(op->var.get()); - let_binding_[op->var] = op; - } PrimExpr body = this->VisitExpr(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { + PrimExpr value = this->VisitExpr(op->value); + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { return body; } else { if (body.same_as(op->body) && value.same_as(op->value)) { @@ -152,96 +129,10 @@ class VarUseDefAnalysis : public StmtExprMutator { } } - PrimExpr VisitExpr_(const VarNode* op) final { - this->HandleUse(GetRef(op)); - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const ReduceNode* op) final { - for (const auto& iv : op->axis) { - this->HandleDef(iv->var.get()); - } - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - LOG(FATAL) << "Unexpected use of deprecated LoadNode. 
Please use BufferLoadNode instead."; - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - VisitBuffer(op->buffer); - return StmtExprMutator::VisitExpr_(op); - } - - void VisitBuffer(Buffer buffer) { - this->HandleUse(buffer->data); - auto visit_arr = [&](Array arr) { - for (const auto& element : arr) { - this->VisitExpr(element); - } - }; - - visit_arr(buffer->shape); - visit_arr(buffer->strides); - } - - void HandleDef(const VarNode* v) { - ICHECK(!def_count_.count(v)) << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - ICHECK(!use_count_.count(v)) << "variable " << v->name_hint - << " has been used before definition!"; - use_count_[v] = 0; - def_count_[v] = 1; - } - - void HandleUse(const PrimExpr& v) { - ICHECK(v.as()); - Var var = Downcast(v); - auto it = use_count_.find(var.get()); - if (it != use_count_.end()) { - if (it->second >= 0) { - ++it->second; - } - } else { - undefined_.push_back(var); - use_count_[var.get()] = -1; - } - } - - // The fields are publically readible to - // be accessible to the users. - bool visit_thread_extent_{true}; - bool simplify_let_{true}; - Array undefined_; - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; - std::unordered_map use_count_; - std::unordered_map def_count_; - - private: - ExprDeepEqual deep_equal_; - std::unordered_map let_binding_; + // pre-computed variable to use count map. 
+ const std::unordered_map& use_count_; }; -Array UndefinedVars(const Stmt& stmt, const Array& args) { - VarUseDefAnalysis m; - m.simplify_let_ = false; - for (Var arg : args) { - m.use_count_[arg.get()] = 0; - } - m(stmt); - return m.undefined_; -} - -Array UndefinedVars(const PrimExpr& expr) { - VarUseDefAnalysis m; - m.simplify_let_ = false; - m(expr); - return m.undefined_; -} - class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) @@ -266,16 +157,19 @@ class HostDeviceSplitter : public StmtMutator { os << name_prefix_ << "_kernel" << device_func_counter_++; std::string kernel_symbol = os.str(); // isolate the device function. - VarUseDefAnalysis m; - m.visit_thread_extent_ = false; - body = m(std::move(body)); + VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); + use_def(body); + DeviceInfoCollector dev_info; + dev_info(body); + UnreferencedLetRemover let_remover(use_def.use_count_); + body = let_remover(std::move(body)); Array params; Array arguments; Map remap_vars; // Strictly order the arguments: Var pointers, positional arguments. - for (Var var : m.undefined_) { + for (Var var : use_def.undefined_) { if (var.dtype().is_handle()) { // Create a new version of v. 
auto it = handle_data_type_.find(var.get()); @@ -295,7 +189,7 @@ class HostDeviceSplitter : public StmtMutator { } } // positional arguments - for (Var var : m.undefined_) { + for (Var var : use_def.undefined_) { if (!var.dtype().is_handle()) { params.push_back(var); arguments.push_back(var); @@ -305,7 +199,8 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_); device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, @@ -313,7 +208,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - if (m.use_dyn_shmem_) { + if (dev_info.use_dyn_shmem_) { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); } @@ -325,11 +220,11 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr arg : arguments) { call_args.push_back(arg); } - for (PrimExpr ext : m.thread_extent_) { + for (PrimExpr ext : dev_info.thread_extent_) { call_args.push_back(ext); } - if (m.use_dyn_shmem_) { - call_args.push_back(m.dyn_shmem_size_); + if (dev_info.use_dyn_shmem_) { + call_args.push_back(dev_info.dyn_shmem_size_); } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); }