From ae95d14ff446fda928626f8faf1f707972b174e0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Dec 2023 15:07:03 -0600 Subject: [PATCH 01/10] [TIR] In SplitHostDevice, check for variables in thread extents Otherwise, they would be undefined after being de-duplicated by `ConvertSSA`. --- src/tir/analysis/verify_well_formed.cc | 147 ++++++ src/tir/ir/tir_visitor_with_path.cc | 431 ++++++++++++++++++ src/tir/ir/tir_visitor_with_path.h | 209 +++++++++ src/tir/transforms/ir_utils.cc | 37 +- src/tir/transforms/split_host_device.cc | 2 +- .../test_tir_analysis_verify_well_formed.py | 71 ++- .../test_tir_transform_split_host_device.py | 39 ++ 7 files changed, 933 insertions(+), 3 deletions(-) create mode 100644 src/tir/ir/tir_visitor_with_path.cc create mode 100644 src/tir/ir/tir_visitor_with_path.h diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 898183533ccd..6adebdcc282c 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -26,12 +26,96 @@ #include #include +#include +#include +#include + #include "../ir/functor_common.h" +#include "../ir/tir_visitor_with_path.h" #include "tvm/ir/module.h" namespace tvm { namespace tir { +namespace { + +template +class Verifier : protected TIRVisitorWithPath { + public: + template + static bool Verify(const TirNodeRef& node, bool assert_on_error) { + DerivedVerifier verifier(assert_on_error); + verifier(node); + return !verifier.has_error_; + } + + protected: + Verifier(bool assert_on_error) : assert_on_error_(assert_on_error) {} + + /* \brief Helper class to handle the bool-or-assert handles + * + * Each verifier can either return a boolean, or assert on failure. + * To avoid needing to duplicate this logic at every step, the + * Verify() method can be used. Similar to `LOG(FATAL)` or + * `LOG(DEBUG)`, it returns an object that can accept streamed + * context information. + * + * If the error should be raised, then the context is collected + * identically to `LOG(FATAL)`. If a boolean is returned, or if the + * condition passes, then the streamed context is discarded. + * + * Usage: + * + * Verify(value == expected_value) + * << "ValueError: " << value + * << " was not the expected value of " << expected_value; + */ + class VerifyStream { + public: + VerifyStream(bool log_fatal) { + if (log_fatal) { + log_.emplace(); + } + } + + VerifyStream(const VerifyStream&) = delete; + VerifyStream& operator=(const VerifyStream&) = delete; + VerifyStream(VerifyStream&& other) { std::swap(log_, other.log_); } + VerifyStream& operator=(VerifyStream&& other) { + std::swap(log_, other.log_); + return *this; + } + + template + VerifyStream& operator<<(T&& t) { + if (log_.has_value()) { + log_.value() << std::forward(t); + } + return *this; + } + + ~VerifyStream() noexcept(false) { + if (log_.has_value()) { + LOG(FATAL) << log_->str(); + } + } + + std::optional log_{std::nullopt}; + }; + + // TODO(Lunderberg): Add the filename/linenum with + // std::source_location when C++20 is available. + VerifyStream Verify(bool condition) { + has_error_ = has_error_ || !condition; + return VerifyStream(!condition && assert_on_error_); + } + + bool assert_on_error_; + bool has_error_{false}; +}; + +} // namespace + /*! \brief Verify all Expr inside the block does not contain: * 1. loop vars outside the current block. * 2. block vars of parent blocks. @@ -135,10 +219,70 @@ class BlockVarAccessVerifier : public StmtExprVisitor { bool has_error_{false}; }; +class UndefinedVarVerifier : public Verifier { + public: + // Until templated-this arrives in C++23, the CRTP can't inject a + // constructor into the child class. Therefore, must explicitly add + // the constructor. + using Verifier::Verifier; + + private: + void EnterDef(const Var& var, ObjectPath path) override { + { + auto it = currently_defined_.find(var); + Verify(it == currently_defined_.end()) + << "ValueError: " + << "TIR is ill-formed, " + << "due to multiple nested definitions of variable " << var + << ". It was first defined at " << it->second << ", and was re-defined at " << path; + } + + { + auto it = previously_defined_.find(var); + Verify(it == previously_defined_.end()) + << "ValueError: " + << "TIR is ill-formed, " + << "due to multiple definitions of variable " << var << ". It was first defined at " + << it->second << ", and was later re-defined at " << path; + } + + currently_defined_.insert({var, path}); + } + + void ExitDef(const Var& var, ObjectPath path) override { + auto active_def = currently_defined_.find(var); + + currently_defined_.erase(active_def); + previously_defined_.insert({var, path}); + } + + void VisitExpr_(const VarNode* op, ObjectPath path) override { + auto var = GetRef(op); + + auto prev_def = previously_defined_.find(var); + Verify(prev_def == previously_defined_.end()) + << "ValueError: " + << "Invalid use of variable " << var << " at " << path << ". " + << "While this variable was previously defined at " << prev_def->second + << ", this definition is no longer in-scope."; + + auto active_def = currently_defined_.find(var); + Verify(active_def != currently_defined_.end()) + << "ValueError: " + << "Invalid use of undefined variable " << var << " at " << path; + } + + std::unordered_map currently_defined_; + std::unordered_map previously_defined_; +}; + bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { if (!BlockVarAccessVerifier::Verify(func, assert_mode)) { return false; } + + if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false; + // TODO(Siyuan): add more checks here. return true; } @@ -152,6 +296,9 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { } } } + + if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false; + return true; } diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc new file mode 100644 index 000000000000..93e034b5d340 --- /dev/null +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/tir_visitor_with_path.cc + * \brief Provide a TIR visitor that tracks the current location + */ + +#include "tir_visitor_with_path.h" + +#include +#include + +namespace tvm { +namespace tir { + +void TIRVisitorWithPath::Visit(const IRModule& mod, ObjectPath path) { + // To ensure deterministic order of visits, sort the GlobalVar first + // by visibility (public then private), then alphabetically by name. + std::vector gvars; + std::unordered_set externally_exposed; + for (const auto& [gvar, func] : mod->functions) { + gvars.push_back(gvar); + if (func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { + externally_exposed.insert(gvar); + } + } + + std::sort(gvars.begin(), gvars.end(), + [&externally_exposed](const GlobalVar& a, const GlobalVar& b) { + bool a_exposed = externally_exposed.count(a); + bool b_exposed = externally_exposed.count(b); + if (a_exposed != b_exposed) { + return a < b; + } else { + return a->name_hint < b->name_hint; + } + }); + + std::vector> context; + + for (const auto& gvar : gvars) { + context.push_back(WithDef(gvar, path->Attr("global_var_map_")->MapValue(gvar->name_hint))); + } + + for (const auto& gvar : gvars) { + auto base_func = mod->functions[gvar]; + if (auto prim_func = base_func.as()) { + Visit(prim_func.value(), path->Attr("functions")->MapValue(gvar)); + } + } + + while (context.size()) context.pop_back(); +} + +void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { + // The implicit definitions from a PrimFunc::buffer_map are pretty + // weird. They only apply if no previous definition of that + // variable has occurred. Therefore, to ensure that we only avoid + // duplicate calls to VisitVarDef, these semantics need to be + // checked. + std::unordered_set defined_params; + std::vector, DefContext>> context; + + auto ppath = path->Attr("params"); + for (size_t i = 0; i < func->params.size(); i++) { + context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i))); + defined_params.insert(func->params[i]); + } + + auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr, + ObjectPath path) { + if (auto opt = expr.as()) { + auto var = opt.value(); + if (!defined_params.count(var)) { + context.push_back(WithDef(var, path)); + defined_params.insert(var); + } + } + }; + auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array& arr, + ObjectPath path) { + for (size_t i = 0; i < arr.size(); i++) { + try_visit_implicit_var_def(arr[i], path->ArrayIndex(i)); + } + }; + + auto buffer_map_path = path->Attr("buffer_map"); + for (size_t i = 0; i < func->params.size(); i++) { + if (auto opt = func->buffer_map.Get(func->params[i])) { + auto buf = opt.value(); + auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); + + // A buffer in the buffer_map always defines its data pointer + context.push_back(WithDef(buf->data, buf_path->Attr("data"))); + + // But other implicit definitions only apply if they weren't + // provided as explicit parameters, and they weren't defined + // implicitly by any previous buffer. + try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape")); + try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides")); + try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset")); + } + } + + // Only after all the implicit definitions have been visited can we + // visit the buffer definition itself. + for (size_t i = 0; i < func->params.size(); i++) { + if (auto opt = func->buffer_map.Get(func->params[i])) { + auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); + EnterDef(opt.value(), buf_path); + } + } + + Visit(func->body, path->Attr("body")); + + while (context.size()) context.pop_back(); +} + +void TIRVisitorWithPath::EnterDef(const IterVar& iter_var, ObjectPath path) { + if (iter_var->dom.defined()) { + Visit(iter_var->dom, path->Attr("dom")); + } + EnterDef(iter_var->var, path->Attr("var")); +} + +void TIRVisitorWithPath::ExitDef(const IterVar& iter_var, ObjectPath path) { + ExitDef(iter_var->var, path->Attr("var")); +} + +void TIRVisitorWithPath::EnterDef(const Buffer& buffer, ObjectPath path) { + // Defining a buffer counts as using all parameters in the buffer + // (e.g. shape/strides). + Visit(buffer->data, path->Attr("data")); + Visit(buffer->shape, path->Attr("shape")); + Visit(buffer->strides, path->Attr("strides")); + Visit(buffer->elem_offset, path->Attr("elem_offset")); +} +void TIRVisitorWithPath::ExitDef(const Buffer& buffer, ObjectPath path) {} + +void TIRVisitorWithPath::Visit(const Buffer& buffer, ObjectPath path) { + // Using a buffer *also* counts as using all parameters in the buffer. + Visit(buffer->data, path->Attr("data")); + Visit(buffer->shape, path->Attr("shape")); + Visit(buffer->strides, path->Attr("strides")); + Visit(buffer->elem_offset, path->Attr("elem_offset")); +} + +void TIRVisitorWithPath::Visit(const BufferRegion& region, ObjectPath path) { + Visit(region->buffer, path->Attr("path")); + Visit(region->region, path->Attr("region")); +} + +void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, ObjectPath path) { + Visit(match->source, path->Attr("source")); + + // MatchBufferRegion define the match->buffer, but do not own the + // body in which the match->buffer is defined. Therefore, the + // definitions are handled in the BlockNode visitor. +} + +void TIRVisitorWithPath::Visit(const IterVar& iter_var, ObjectPath path) { + if (iter_var->dom.defined()) { + Visit(iter_var->dom, path->Attr("dom")); + } + Visit(iter_var->var, path->Attr("var")); +} + +void TIRVisitorWithPath::Visit(const Range& range, ObjectPath path) { + Visit(range->min, path->Attr("min")); + Visit(range->extent, path->Attr("extent")); +} + +void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); + auto context = WithDef(op->var, path->Attr("var")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); + + std::optional> context = std::nullopt; + if (auto ptr = op->node.as(); ptr && op->attr_key == attr::thread_extent) { + // Some attributes serve as a source of definition for the + // tir::Var they annotate. + Visit(ptr->dom, path->Attr("node")->Attr("dom")); + context = WithDef(ptr->var, path->Attr("node")->Attr("var")); + } else if (auto expr = op->node.as()) { + Visit(expr.value(), path->Attr("node")); + } + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) { + Visit(op->min, path->Attr("min")); + Visit(op->extent, path->Attr("extent")); + auto context = WithDef(op->loop_var, path->Attr("loop_var")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, ObjectPath path) { + Visit(op->condition, path->Attr("condition")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const AllocateNode* op, ObjectPath path) { + Visit(op->condition, path->Attr("condition")); + Visit(op->extents, path->Attr("extents")); + auto context = WithDef(op->buffer_var, path->Attr("buffer_var")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const AllocateConstNode* op, ObjectPath path) { + Visit(op->extents, path->Attr("extents")); + auto context = WithDef(op->buffer_var, path->Attr("buffer_var")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, ObjectPath path) { + auto context = WithDef(op->buffer, path->Attr("buffer")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); + Visit(op->buffer, path->Attr("buffer")); + Visit(op->indices, path->Attr("indices")); +} + +void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) { + Visit(op->condition, path->Attr("condition")); + Visit(op->bounds, path->Attr("bounds")); + auto context = WithDef(op->buffer, path->Attr("buffer")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, ObjectPath path) { + Visit(op->condition, path->Attr("condition")); + Visit(op->then_case, path->Attr("then_case")); + Visit(op->else_case, path->Attr("else_case")); +} + +void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ObjectPath path) { + Visit(op->condition, path->Attr("condition")); + Visit(op->message, path->Attr("message")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitStmt_(const ProducerStoreNode* op, ObjectPath path) { + Visit(op->indices, path->Attr("indices")); + Visit(op->value, path->Attr("value")); +} + +void TIRVisitorWithPath::VisitStmt_(const ProducerRealizeNode* op, ObjectPath path) { + Visit(op->bounds, path->Attr("bounds")); + Visit(op->body, path->Attr("body")); + Visit(op->condition, path->Attr("condition")); +} + +void TIRVisitorWithPath::VisitStmt_(const PrefetchNode* op, ObjectPath path) { + Visit(op->bounds, path->Attr("bounds")); +} + +void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, ObjectPath path) { + Visit(op->seq, path->Attr("seq")); +} + +void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); +} + +void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { + std::vector, DefContext, DefContext>> context; + + { + auto iter_path = path->Attr("iter_vars"); + for (size_t i = 0; i < op->iter_vars.size(); i++) { + context.push_back(WithDef(op->iter_vars[i], iter_path->ArrayIndex(i))); + } + } + Visit(op->reads, path->Attr("reads")); + Visit(op->writes, path->Attr("writes")); + + { + auto alloc_path = path->Attr("alloc_buffers"); + for (size_t i = 0; i < op->alloc_buffers.size(); i++) { + auto buffer_path = alloc_path->ArrayIndex(i); + auto buf = op->alloc_buffers[i]; + context.push_back(WithDef(buf->data, buffer_path->Attr("data"))); + context.push_back(WithDef(buf, buffer_path)); + } + } + + { + auto match_path = path->Attr("match_buffers"); + Visit(op->match_buffers, match_path); + + for (size_t i = 0; i < op->match_buffers.size(); i++) { + auto buf = op->match_buffers[i]->buffer; + auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer"); + context.push_back(WithDef(buf->data, buffer_path->Attr("data"))); + context.push_back(WithDef(buf, buffer_path)); + } + } + + Visit(op->init, path->Attr("init")); + Visit(op->body, path->Attr("body")); + + while (context.size()) context.pop_back(); +} + +void TIRVisitorWithPath::VisitStmt_(const BlockRealizeNode* op, ObjectPath path) { + Visit(op->iter_values, path->Attr("iter_values")); + Visit(op->predicate, path->Attr("predicate")); + Visit(op->block, path->Attr("block")); +} + +void TIRVisitorWithPath::VisitExpr_(const VarNode* op, ObjectPath path) {} + +void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, ObjectPath path) { + VisitExpr_(static_cast(op), path); +} + +void TIRVisitorWithPath::VisitExpr_(const AnyNode* op, ObjectPath path) {} + +void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, ObjectPath path) { + Visit(op->buffer, path->Attr("buffer")); + Visit(op->indices, path->Attr("indices")); +} + +void TIRVisitorWithPath::VisitExpr_(const ProducerLoadNode* op, ObjectPath path) { + Visit(op->indices, path->Attr("indices")); +} + +void TIRVisitorWithPath::VisitExpr_(const LetNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); + auto context = WithDef(op->var, path->Attr("var")); + Visit(op->body, path->Attr("body")); +} + +void TIRVisitorWithPath::VisitExpr_(const CallNode* op, ObjectPath path) { + if (auto gvar = op->op.as()) { + Visit(gvar.value(), path->Attr("op")); + } + Visit(op->args, path->Attr("args")); +} + +#define DEFINE_BINOP_VISIT_(OP) \ + void TIRVisitorWithPath::VisitExpr_(const OP* op, ObjectPath path) { \ + Visit(op->a, path->Attr("a")); \ + Visit(op->b, path->Attr("b")); \ + } + +DEFINE_BINOP_VISIT_(AddNode); +DEFINE_BINOP_VISIT_(SubNode); +DEFINE_BINOP_VISIT_(MulNode); +DEFINE_BINOP_VISIT_(DivNode); +DEFINE_BINOP_VISIT_(ModNode); +DEFINE_BINOP_VISIT_(FloorDivNode); +DEFINE_BINOP_VISIT_(FloorModNode); +DEFINE_BINOP_VISIT_(MinNode); +DEFINE_BINOP_VISIT_(MaxNode); +DEFINE_BINOP_VISIT_(EQNode); +DEFINE_BINOP_VISIT_(NENode); +DEFINE_BINOP_VISIT_(LTNode); +DEFINE_BINOP_VISIT_(LENode); +DEFINE_BINOP_VISIT_(GTNode); +DEFINE_BINOP_VISIT_(GENode); +DEFINE_BINOP_VISIT_(AndNode); +DEFINE_BINOP_VISIT_(OrNode); + +#undef DEFINE_BINOP_VISIT_ + +void TIRVisitorWithPath::VisitExpr_(const IntImmNode* op, ObjectPath path) {} +void TIRVisitorWithPath::VisitExpr_(const FloatImmNode* op, ObjectPath path) {} +void TIRVisitorWithPath::VisitExpr_(const StringImmNode* op, ObjectPath path) {} + +void TIRVisitorWithPath::VisitExpr_(const ReduceNode* op, ObjectPath path) { + Visit(op->axis, path->Attr("axis")); + Visit(op->source, path->Attr("source")); + Visit(op->init, path->Attr("init")); + Visit(op->condition, path->Attr("condition")); +} + +void TIRVisitorWithPath::VisitExpr_(const CastNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); +} + +void TIRVisitorWithPath::VisitExpr_(const NotNode* op, ObjectPath path) { + Visit(op->a, path->Attr("a")); +} + +void TIRVisitorWithPath::VisitExpr_(const SelectNode* op, ObjectPath path) { + Visit(op->condition, path->Attr("condition")); + Visit(op->true_value, path->Attr("true_value")); + Visit(op->false_value, path->Attr("false_value")); +} + +void TIRVisitorWithPath::VisitExpr_(const RampNode* op, ObjectPath path) { + Visit(op->base, path->Attr("base")); + Visit(op->stride, path->Attr("stride")); +} + +void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op, ObjectPath path) { + Visit(op->indices, path->Attr("indices")); + Visit(op->vectors, path->Attr("vectors")); +} + +void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, ObjectPath path) { + Visit(op->value, path->Attr("value")); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h new file mode 100644 index 000000000000..aa2adc037bfe --- /dev/null +++ b/src/tir/ir/tir_visitor_with_path.h @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/tir_visitor_with_path.h + * \brief Provide a TIR visitor that tracks the current location + */ +#ifndef TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ +#define TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! \brief Visit TIR while tracking the ObjectPath */ +class TIRVisitorWithPath : protected ExprFunctor, + protected StmtFunctor { + public: + template + void operator()(TObjectRef&& obj) { + Visit(std::forward(obj), ObjectPath::Root()); + } + + protected: + // Delegate to ExprFunctor::VisitExpr for PrimExpr, and any subclasses + inline void Visit(const PrimExpr& obj, ObjectPath path) { VisitExpr(obj, path); } + // Delegate to ExprFunctor::VisitStmt for Stmt, and any subclasses + inline void Visit(const Stmt& obj, ObjectPath path) { VisitStmt(obj, path); } + + // Visitors for TIR constructs that are neither PrimExpr nor Stmt + virtual void Visit(const IRModule& obj, ObjectPath path); + virtual void Visit(const PrimFunc& obj, ObjectPath path); + virtual void Visit(const GlobalVar& obj, ObjectPath path) {} + virtual void Visit(const Range& obj, ObjectPath path); + virtual void Visit(const Buffer& obj, ObjectPath path); + virtual void Visit(const BufferRegion& obj, ObjectPath path); + virtual void Visit(const MatchBufferRegion& obj, ObjectPath path); + virtual void Visit(const IterVar& obj, ObjectPath path); + + // Called when entering/exiting the scope of a GlobalVar definition. + virtual void EnterDef(const GlobalVar& var, ObjectPath path) {} + virtual void ExitDef(const GlobalVar& var, ObjectPath path) {} + + // Called when entering/exiting the scope of a tir::Var definition. + virtual void EnterDef(const Var& var, ObjectPath path) {} + virtual void ExitDef(const Var& var, ObjectPath path) {} + + // Called when entering/exiting the scope of an IterVar definition. + // By default, visits the `Range IterVarNode::dom`, then enters the + // scope of the internal `tir::Var`. + virtual void EnterDef(const IterVar& var, ObjectPath path); + virtual void ExitDef(const IterVar& var, ObjectPath path); + + // Called when entering/exiting the scope of a Buffer definition. + // By default, visits the buffer's data pointer, shape, strides, and + // elem_offset, which must be defined prior to defining the Buffer. + virtual void EnterDef(const Buffer& buffer, ObjectPath path); + virtual void ExitDef(const Buffer& buffer, ObjectPath path); + + // Utility to visit an array of nodes + template + inline void Visit(const Array& arr, ObjectPath path) { + for (size_t i = 0; i < arr.size(); i++) { + Visit(arr[i], path->ArrayIndex(i)); + } + } + + // Utility to visit an optional node nodes + template + inline void Visit(const Optional& opt, ObjectPath path) { + if (opt) { + Visit(opt.value(), path); + } + } + + using StmtFunctor::VisitStmt; + void VisitStmt_(const AttrStmtNode* op, ObjectPath path) override; + void VisitStmt_(const IfThenElseNode* op, ObjectPath path) override; + void VisitStmt_(const LetStmtNode* op, ObjectPath path) override; + void VisitStmt_(const ForNode* op, ObjectPath path) override; + void VisitStmt_(const WhileNode* op, ObjectPath path) override; + void VisitStmt_(const AllocateNode* op, ObjectPath path) override; + void VisitStmt_(const AllocateConstNode* op, ObjectPath path) override; + void VisitStmt_(const DeclBufferNode* op, ObjectPath path) override; + void VisitStmt_(const BufferStoreNode* op, ObjectPath path) override; + void VisitStmt_(const BufferRealizeNode* op, ObjectPath path) override; + void VisitStmt_(const AssertStmtNode* op, ObjectPath path) override; + void VisitStmt_(const ProducerStoreNode* op, ObjectPath path) override; + void VisitStmt_(const ProducerRealizeNode* op, ObjectPath path) override; + void VisitStmt_(const PrefetchNode* op, ObjectPath path) override; + void VisitStmt_(const SeqStmtNode* op, ObjectPath path) override; + void VisitStmt_(const EvaluateNode* op, ObjectPath path) override; + void VisitStmt_(const BlockNode* op, ObjectPath path) override; + void VisitStmt_(const BlockRealizeNode* op, ObjectPath path) override; + + using ExprFunctor::VisitExpr; + void VisitExpr_(const VarNode* op, ObjectPath path) override; + void VisitExpr_(const SizeVarNode* op, ObjectPath path) override; + void VisitExpr_(const BufferLoadNode* op, ObjectPath path) override; + void VisitExpr_(const ProducerLoadNode* op, ObjectPath path) override; + void VisitExpr_(const LetNode* op, ObjectPath path) override; + void VisitExpr_(const CallNode* op, ObjectPath path) override; + void VisitExpr_(const AddNode* op, ObjectPath path) override; + void VisitExpr_(const SubNode* op, ObjectPath path) override; + void VisitExpr_(const MulNode* op, ObjectPath path) override; + void VisitExpr_(const DivNode* op, ObjectPath path) override; + void VisitExpr_(const ModNode* op, ObjectPath path) override; + void VisitExpr_(const FloorDivNode* op, ObjectPath path) override; + void VisitExpr_(const FloorModNode* op, ObjectPath path) override; + void VisitExpr_(const MinNode* op, ObjectPath path) override; + void VisitExpr_(const MaxNode* op, ObjectPath path) override; + void VisitExpr_(const EQNode* op, ObjectPath path) override; + void VisitExpr_(const NENode* op, ObjectPath path) override; + void VisitExpr_(const LTNode* op, ObjectPath path) override; + void VisitExpr_(const LENode* op, ObjectPath path) override; + void VisitExpr_(const GTNode* op, ObjectPath path) override; + void VisitExpr_(const GENode* op, ObjectPath path) override; + void VisitExpr_(const AndNode* op, ObjectPath path) override; + void VisitExpr_(const OrNode* op, ObjectPath path) override; + void VisitExpr_(const ReduceNode* op, ObjectPath path) override; + void VisitExpr_(const CastNode* op, ObjectPath path) override; + void VisitExpr_(const NotNode* op, ObjectPath path) override; + void VisitExpr_(const SelectNode* op, ObjectPath path) override; + void VisitExpr_(const RampNode* op, ObjectPath path) override; + void VisitExpr_(const BroadcastNode* op, ObjectPath path) override; + void VisitExpr_(const ShuffleNode* op, ObjectPath path) override; + void VisitExpr_(const IntImmNode* op, ObjectPath path) override; + void VisitExpr_(const FloatImmNode* op, ObjectPath path) override; + void VisitExpr_(const StringImmNode* op, ObjectPath path) override; + void VisitExpr_(const AnyNode* op, ObjectPath path) override; + + // Utility to call EnterDef/ExitDef. Used in the implementation of + // WithDef. + template + class DefContext { + public: + DefContext(DefContext&& other) { swap(std::move(other)); } + DefContext& operator=(DefContext&& other) { + swap(std::move(other)); + return *this; + } + + DefContext(const DefContext&) = delete; + DefContext& operator=(const DefContext&) = delete; + ~DefContext() noexcept(false) { + // Checks performed when a definition goes out of scope may + // raise an exception. If the stack is already being unwound + // due to another exception being thrown, this would cause a + // segfault and terminate the program. By checking that no + // additional exceptions have been thrown between the + // construction of the DefContext and the destruction, we avoid + // this case and allow the first error to propagate upward. + if (self_ && std::uncaught_exceptions() == uncaught_exceptions_) { + self_->ExitDef(obj_, path_); + } + } + + private: + friend class TIRVisitorWithPath; + + DefContext(TIRVisitorWithPath* self, T obj, ObjectPath path) + : self_(self), obj_(obj), path_(path), uncaught_exceptions_(std::uncaught_exceptions()) { + self_->EnterDef(obj_, path_); + } + + void swap(DefContext&& other) { + std::swap(this->self_, other.self_); + std::swap(this->obj_, other.obj_); + std::swap(this->path_, other.path_); + std::swap(this->uncaught_exceptions_, other.uncaught_exceptions_); + } + + TIRVisitorWithPath* self_{nullptr}; + T obj_; + ObjectPath path_{ObjectPath::Root()}; + int uncaught_exceptions_{-1}; + }; + + // Utility to track the scope of a node's definition. + template + DefContext WithDef(T obj, ObjectPath path) { + return DefContext(this, obj, path); + } +}; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 6b681c07e5d5..e102e40dcccb 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -343,7 +343,42 @@ class IRConvertSSA final : public StmtExprMutator { } } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (const VarNode* v = op->node.as()) { + if (const IterVarNode* iter_var = op->node.as()) { + Range dom = iter_var->dom; + if (dom.defined()) { + auto min = VisitExpr(dom->min); + auto extent = VisitExpr(dom->extent); + if (!min.same_as(iter_var->dom->min) || !extent.same_as(iter_var->dom->extent)) { + dom = Range::FromMinExtent(min, extent); + } + } + + std::optional context = std::nullopt; + auto var = iter_var->var; + if (defined_.count(var.get())) { + context.emplace(this, var); + var = context->new_var; + } else { + defined_.insert(var.get()); + } + + IterVar new_iter_var; + if (dom.same_as(iter_var->dom) && var.same_as(iter_var->var)) { + new_iter_var = GetRef(iter_var); + } else { + new_iter_var = IterVar(dom, var, iter_var->iter_type, iter_var->thread_tag, iter_var->span); + } + + auto value = VisitExpr(op->value); + auto body = VisitStmt(op->body); + + if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span); + } + + } else if (const VarNode* v = op->node.as()) { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index b9fc056f1962..c90384fea73a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -57,7 +57,7 @@ class HostDeviceSplitter : public StmtMutator { private: Stmt SplitDeviceFunc(Stmt body, Target device_target) { auto [params, buffers_to_declare] = [&]() -> std::tuple, Array> { - VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); + VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true); use_def(body); // Sort first by variable type, then by variable name diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index 4f88cc8be1e1..4fee603062ac 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import pytest + import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import ir as I, tir as T def test_pass_simple(): @@ -54,5 +57,71 @@ def element_wise( assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False) +def test_error_for_out_of_scope_usage(): + """A variable may not be used after its scope ends""" + + @T.prim_func + def func(): + i = T.int32() + with T.LetStmt(42, var=i): + T.evaluate(i) + T.evaluate(i) + + with pytest.raises(ValueError, match="Invalid use of variable i at .* no longer in-scope."): + tvm.tir.analysis.verify_well_formed(func) + + +def test_error_for_nested_rebind_usage(): + """A variable may not be re-defined within the initial scope""" + + @T.prim_func + def func(): + i = T.int32() + with T.LetStmt(42, var=i): + with T.LetStmt(42, var=i): + T.evaluate(i) + + with pytest.raises( + ValueError, match="ill-formed, due to multiple nested definitions of variable i" + ): + tvm.tir.analysis.verify_well_formed(func) + + +def test_error_for_repeated_binding(): + """A variable may not be re-defined after the scope ends""" + + @T.prim_func + def func(): + i = T.int32() + with T.LetStmt(42, var=i): + T.evaluate(i) + with T.LetStmt(17, var=i): + T.evaluate(i) + + with pytest.raises(ValueError, match="multiple definitions of variable i"): + tvm.tir.analysis.verify_well_formed(func) + + +def test_error_for_cross_function_reuse(): + """A variable may not be re-defined in another function""" + + i = tvm.tir.Var("i", "int32") + + @I.ir_module + class mod: + @T.prim_func + def func1(): + with T.LetStmt(42, var=i): + T.evaluate(i) + + @T.prim_func + def func2(): + with T.LetStmt(42, var=i): + T.evaluate(i) + + with pytest.raises(ValueError, match="multiple definitions of variable i"): + tvm.tir.analysis.verify_well_formed(mod) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tir-transform/test_tir_transform_split_host_device.py index b61fcc66014e..790695f0555f 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py @@ -273,5 +273,44 @@ def main_kernel(): return mod +def test_dynamic_launch_thread(): + """Dynamic T.launch_thread may depend on host-side variable + + A dynamic parameter for `T.launch_thread` may have an extent that + is computed using variables outside of the `T.target` section. + + This is a regression test to catch a previous failure mode, in + which SplitHostDevice generated output with undefined variables, + if the only use of a variable occurred in the extent of a + `T.launch_thread` statement. + + While the lowering pass `LowerDeviceKernelLaunch` will hoist the + computation of the extent from the device kernel to the host + function, the IRModule must be well-defined at all stages of + lowering. Even if a variable is only used as part of a thread + extent, `SplitHostDevice` should treat it as a kernel parameter, to + provide a definition of the variable within the TIR device kernel. + """ + + @I.ir_module + class before: + @T.prim_func + def default_function(var_A: T.handle, var_B: T.handle, seq_len: T.int32): + T.func_attr({"target": T.target("cuda")}) + + A = T.match_buffer(var_A, [seq_len], "int32") + B = T.match_buffer(var_B, [seq_len], "int32") + + num_blocks: T.int32 = (seq_len + 127) // 128 + with T.attr(T.target("cuda"), "target", 0): + blockIdx_x = T.launch_thread("blockIdx.x", num_blocks) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + if blockIdx_x * 128 + threadIdx_x < seq_len: + B[blockIdx_x * 128 + threadIdx_x] = A[blockIdx_x * 128 + threadIdx_x] + + after = tvm.tir.transform.SplitHostDevice()(before) + tvm.tir.analysis.verify_well_formed(after) + + if __name__ == "__main__": tvm.testing.main() From 0ef00cb3388fc71fa8443e7900b170a78144610a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Dec 2023 15:09:15 -0600 Subject: [PATCH 02/10] Revert #16236 The buf reported in #16237 can be resolved by tracking variable usage in a thread extent. --- src/tir/transforms/ir_utils.cc | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index e102e40dcccb..132844211816 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -750,13 +750,8 @@ Pass ConvertSSA() { tir::IRConvertSSA converter; Map functions; bool made_change = false; - // FIXME: This is just a temporal workaround to ensure free vars - // in device function have the same pointer as the host function for (auto [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - if (!ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - continue; - } auto updated = converter.VisitPrimFunc(GetRef(ptr)); if (!updated.same_as(base_func)) { made_change = true; @@ -765,19 +760,6 @@ Pass ConvertSSA() { } functions.Set(gvar, base_func); } - for (auto [gvar, base_func] : mod->functions) { - if (auto* ptr = base_func.as()) { - if (ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - continue; - } - auto updated = converter.VisitPrimFunc(GetRef(ptr)); - if (!updated.same_as(base_func)) { - made_change = true; - base_func = updated; - } - functions.Set(gvar, base_func); - } - } if (made_change) { mod.CopyOnWrite()->functions = std::move(functions); } From 30f7b3bb52b1ed6afd37458a7a5811b3063c086a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Dec 2023 09:00:41 -0600 Subject: [PATCH 03/10] lint fixes --- src/tir/analysis/verify_well_formed.cc | 4 ++-- src/tir/ir/tir_visitor_with_path.cc | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 6adebdcc282c..59d89e6c2aa4 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -50,7 +50,7 @@ class Verifier : protected TIRVisitorWithPath { } protected: - Verifier(bool assert_on_error) : assert_on_error_(assert_on_error) {} + explicit Verifier(bool assert_on_error) : assert_on_error_(assert_on_error) {} /* \brief Helper class to handle the bool-or-assert handles * @@ -72,7 +72,7 @@ class Verifier : protected TIRVisitorWithPath { */ class VerifyStream { public: - VerifyStream(bool log_fatal) { + explicit VerifyStream(bool log_fatal) { if (log_fatal) { log_.emplace(); } diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 93e034b5d340..546c691bf426 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -24,8 +24,12 @@ #include "tir_visitor_with_path.h" +#include #include +#include +#include #include +#include namespace tvm { namespace tir { From 261709dc600a9cc7265ccf88b7efb390ebff0549 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Dec 2023 10:46:28 -0600 Subject: [PATCH 04/10] Update TIR well-formed checker for env thread SSA requirements Environment threads must reuse the same `tir::Var` across all `AttrStmt` instances in a `PrimFunc`, but must not reuse across separate `PrimFunc`s in an `IRModule`. --- src/tir/analysis/verify_well_formed.cc | 91 ++++++++++++++++--- src/tir/ir/tir_visitor_with_path.cc | 8 +- .../test_tir_analysis_verify_well_formed.py | 75 +++++++++++++++ 3 files changed, 158 insertions(+), 16 deletions(-) diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 59d89e6c2aa4..58eadb20fa01 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -28,6 +28,7 @@ #include #include +#include #include #include "../ir/functor_common.h" @@ -227,10 +228,23 @@ class UndefinedVarVerifier : public Verifier { using Verifier::Verifier; private: + void Visit(const PrimFunc& prim_func, ObjectPath path) override { + Verifier::Visit(prim_func, path); + redefine_allowed_within_function_.clear(); + } + + void EnterDef(const IterVar& iter_var, ObjectPath path) override { + Verifier::EnterDef(iter_var, path); + if (iter_var->iter_type == IterVarType::kThreadIndex) { + redefine_allowed_within_function_.insert(iter_var->var); + } + } + void EnterDef(const Var& var, ObjectPath path) override { + bool redefine_is_allowed = redefine_allowed_within_function_.count(var); { auto it = currently_defined_.find(var); - Verify(it == currently_defined_.end()) + Verify(it == currently_defined_.end() || redefine_is_allowed) << "ValueError: " << "TIR is ill-formed, " << "due to multiple nested definitions of variable " << var @@ -239,7 +253,7 @@ class UndefinedVarVerifier : public Verifier { { auto it = previously_defined_.find(var); - Verify(it == previously_defined_.end()) + Verify(it == previously_defined_.end() || redefine_is_allowed) << "ValueError: " << "TIR is ill-formed, " << "due to multiple definitions of variable " << var << ". It was first defined at " @@ -259,21 +273,72 @@ class UndefinedVarVerifier : public Verifier { void VisitExpr_(const VarNode* op, ObjectPath path) override { auto var = GetRef(op); - auto prev_def = previously_defined_.find(var); - Verify(prev_def == previously_defined_.end()) - << "ValueError: " - << "Invalid use of variable " << var << " at " << path << ". " - << "While this variable was previously defined at " << prev_def->second - << ", this definition is no longer in-scope."; - auto active_def = currently_defined_.find(var); - Verify(active_def != currently_defined_.end()) - << "ValueError: " - << "Invalid use of undefined variable " << var << " at " << path; + auto verify = Verify(active_def != currently_defined_.end()); + verify << "ValueError: " + << "Invalid use of undefined variable " << var << " at " << path << "."; + + // Check if there was a previous definition, and append the + // location to the error message if there was. This is to aid in + // debugging, by distinguishing between a variable that is + // currently out-of-scope, and a variable that never had a + // definition in the first place. + if (auto prev_def = previously_defined_.find(var); prev_def != previously_defined_.end()) { + verify << ". While this variable was previously defined at " << prev_def->second + << ", this definition is no longer in-scope."; + } } + // Variables that are defined in the currently-visited scope. std::unordered_map currently_defined_; + + // Variables that were previously defined, and are now out of scope. std::unordered_map previously_defined_; + + // Special variables that are allowed to be re-defined, so long as + // that re-definition occurs within the same PrimFunc. For example + std::unordered_set redefine_allowed_within_function_; +}; + +/* \brief Verify unique tir::Var for each environment thread + * + * Environment threads, such as CUDA's `threadIdx.x`, are defined in + * TIR using an `AttrStmt` with the key `attr::thread_extent`. A + * `PrimFunc` may contain multiple such attributes for the same + * environment thread. However, all such attributes must use the same + * `tir::Var` for a given thread. + */ +class SingleEnvThreadVerifier : public Verifier { + public: + using Verifier::Verifier; + + private: + void Visit(const PrimFunc& prim_func, ObjectPath path) override { + Verifier::Visit(prim_func, path); + env_thread_vars_.clear(); + } + + void EnterDef(const IterVar& iter_var, ObjectPath path) override { + if (iter_var->iter_type == IterVarType::kThreadIndex) { + if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) { + const auto& [prev_var, prev_path] = it->second; + Verify(prev_var.same_as(iter_var->var)) + << "ValueError: " + << "PrimFunc uses multiple distinct TIR variables " + << " for the environment thread \"" << iter_var->thread_tag << "\". " + << "While multiple tir::AttrStmt may define the same environment thread, " + << "all definitions within a single PrimFunc must share the same tir::Var. " + << "Binding of environment thread \"" << iter_var->thread_tag + << "\" to the TIR variable " << iter_var->var << " at " << path + << " conflicts with the previous binding to the TIR variable " << prev_var << " at " + << path; + } else { + env_thread_vars_.insert({iter_var->thread_tag, {iter_var->var, path}}); + } + } + } + + std::unordered_map> env_thread_vars_; }; bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { @@ -282,6 +347,7 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { } if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false; + if (!SingleEnvThreadVerifier::Verify(func, assert_mode)) return false; // TODO(Siyuan): add more checks here. return true; @@ -298,6 +364,7 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { } if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false; + if (!SingleEnvThreadVerifier::Verify(mod, assert_mode)) return false; return true; } diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 546c691bf426..8918f4c716e6 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -200,12 +200,12 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) { void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { Visit(op->value, path->Attr("value")); - std::optional> context = std::nullopt; - if (auto ptr = op->node.as(); ptr && op->attr_key == attr::thread_extent) { + std::optional> context = std::nullopt; + if (auto iter_var = op->node.as(); + iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) { // Some attributes serve as a source of definition for the // tir::Var they annotate. - Visit(ptr->dom, path->Attr("node")->Attr("dom")); - context = WithDef(ptr->var, path->Attr("node")->Attr("var")); + context = WithDef(iter_var.value(), path->Attr("node")); } else if (auto expr = op->node.as()) { Visit(expr.value(), path->Attr("node")); } diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index 4fee603062ac..f7cc8eff6d76 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -123,5 +123,80 @@ def func2(): tvm.tir.analysis.verify_well_formed(mod) +def test_reuse_of_env_thread_in_function_is_well_formed(): + """An env thread may be reused within a PrimFunc + + The `T.env_thread` has unique semantics, and may be defined at + multiple locations without the TIR being considered ill-formed. + """ + + @T.prim_func + def func(A: T.Buffer([256], "float32")): + threadIdx_x = T.env_thread("threadIdx.x") + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 1.0 + + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + tvm.tir.analysis.verify_well_formed(func) + + +def test_reuse_of_env_thread_in_function_is_mandatory(): + """An env thread may be reused within a PrimFunc + + Not only are environment threads allowed to have multiple + definition sites, it is mandatory for them to have multiple + definition sites. If a PrimFunc contains more than one + `"thread_extent"` with the same name, but with different `tir.Var` + instances, it is ill-formed. + """ + + @T.prim_func + def func(A: T.Buffer([256], "float32")): + with T.launch_thread("threadIdx.x", 256) as threadIdx_x: + A[threadIdx_x] = A[threadIdx_x] + 1.0 + + with T.launch_thread("threadIdx.x", 256) as threadIdx_x: + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + with pytest.raises(ValueError): + tvm.tir.analysis.verify_well_formed(func) + + +def test_reuse_of_env_thread_across_functions_is_ill_formed(): + """An env thread may not be reused across PrimFunc + + However, each function must have its own `tir.Var` representing + the environment thread, and may not share these variables across + PrimFuncs. + """ + + threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") + + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + with pytest.raises(ValueError, match="multiple definitions of variable threadIdx_x"): + tvm.tir.analysis.verify_well_formed(mod) + + if __name__ == "__main__": tvm.testing.main() From eb6bfb945f580589eac7c1df0fca1351a43c75fd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Dec 2023 11:13:36 -0600 Subject: [PATCH 05/10] Update ConvertSSA to handle environment threads' SSA requirements --- src/tir/transforms/ir_utils.cc | 27 ++- .../test_tir_analysis_verify_well_formed.py | 4 +- .../test_tir_transform_convert_ssa.py | 216 ++++++++++++++++++ 3 files changed, 241 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 132844211816..a85bde6787f0 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -207,6 +207,7 @@ class IRConvertSSA final : public StmtExprMutator { while (redefines.size()) { redefines.pop_back(); } + function_scope_var_remap_.clear(); return func; } @@ -259,6 +260,9 @@ class IRConvertSSA final : public StmtExprMutator { Var GetRemappedVar(Var var) { if (auto it = scope_.find(var.get()); it != scope_.end() && it->second.size()) { return it->second.back(); + } else if (auto it = function_scope_var_remap_.find(var.get()); + it != function_scope_var_remap_.end()) { + return it->second; } else { return var; } @@ -353,12 +357,23 @@ class IRConvertSSA final : public StmtExprMutator { } } - std::optional context = std::nullopt; - auto var = iter_var->var; - if (defined_.count(var.get())) { - context.emplace(this, var); - var = context->new_var; + Var var = iter_var->var; + if (auto it = function_scope_var_remap_.find(var.get()); + it != function_scope_var_remap_.end()) { + var = it->second; + } else if (defined_.count(var.get())) { + Var new_var = [&]() { + if (var->type_annotation.defined()) { + return Var(var->name_hint, var->type_annotation); + } else { + return Var(var->name_hint, var->dtype); + } + }(); + + function_scope_var_remap_.insert({var.get(), new_var}); + var = new_var; } else { + function_scope_var_remap_.insert({var.get(), var}); defined_.insert(var.get()); } @@ -437,6 +452,8 @@ class IRConvertSSA final : public StmtExprMutator { std::unordered_map> scope_; std::unordered_set defined_; std::unordered_map> buf_remap_; + + std::unordered_map function_scope_var_remap_; }; Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index f7cc8eff6d76..e839f44b3306 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -67,7 +67,9 @@ def func(): T.evaluate(i) T.evaluate(i) - with pytest.raises(ValueError, match="Invalid use of variable i at .* no longer in-scope."): + with pytest.raises( + ValueError, match="Invalid use of undefined variable i at .* no longer in-scope." + ): tvm.tir.analysis.verify_well_formed(func) diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 38a93b199e44..140adcd35bd2 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -269,5 +269,221 @@ def expected(self): return tvm.IRModule({"func_a": self._make_func(), "func_b": self._make_func()}) +class TestKeepDuplicateThreadIdxInSameFunction(BaseBeforeAfter): + """Environment threads are treated as being at function scope + + The `"thread_extent"` attribute has some unique semantics. It + serves as the definition of the `tir::Var` representing the + environment thread (e.g. `threadIdx.x` in CUDA). However, + multiple `"thread_extent"` attributes may co-exist in the same + PrimFunc. For the purpose of variable scope, use of the + `tir::Var` is only allowed within the body of the `AttrStmt`. + However, for the purpose of well-formed-ness, all + `"thread_extent"` attributes must use the same IterVar instance + (e.g. `WarpIndexFinder` in `lower_warp_memory.cc` may throw an + error if multiple IterVar instances occur). + + If there are multiple `AttrStmt` with key `"thread_extent"` in a + single function (represented in TVMScript as `T.launch_thread`), + these should be treated as a definition of a single variable at + function scope, and should not be de-duplicated. + """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer([256], "float32")): + threadIdx_x = T.env_thread("threadIdx.x") + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 1.0 + + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + return mod + + expected = before + + +class TestDeDuplicateThreadIdxAcrossMultipleFunctions(BaseBeforeAfter): + """Environment threads are treated as being at function scope + + See `TestKeepDuplicateThreadIdxInSameFunction` for background + information. + + If there are multiple functions in an IRModule, the `AttrStmt` + with key `"thread_extent"` in a single function (represented in + TVMScript as `T.launch_thread`), these should be treated as a + definition of a single variable at function scope, and should not + be de-duplicated. + + For this test case, the `AttrStmt` for `"thread_extent"` are + written explicitly, without using the usual `T.env_thread` and + `T.launch_thread`, as they cannot represent the duplciate + Var/IterVar usage across the two PrimFuncs. + """ + + def before(self): + threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") + + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + threadIdx_x = T.int32() + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + threadIdx_x = T.int32() + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + return mod + + +class TestDeDuplicateThreadIdxIterVarAcrossMultipleFunctions(BaseBeforeAfter): + """Environment threads are treated as being at function scope + + Like `TestDeDuplicateThreadIdxAcrossMultipleFunctions`, except the + `IterVar` for the environment thread is duplicated across multiple + PrimFuncs, not just the `tir.Var` inside the `IterVar`. + """ + + def before(self): + threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") + iter_var = tvm.tir.IterVar( + tvm.ir.Range(0, 256), threadIdx_x, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" + ) + + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + T.attr(iter_var, "thread_extent", 256) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + T.attr(iter_var, "thread_extent", 256) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + threadIdx_x = T.int32() + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + threadIdx_x = T.int32() + T.attr( + T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"), + "thread_extent", + 256, + ) + A[threadIdx_x] = A[threadIdx_x] + T.float32(1) + + return mod + + +class TestThreadIdxReusedWithinAndAcrossFunctions(BaseBeforeAfter): + """Environment threads are treated as being at function scope + + A combination of + TestDeDuplicateThreadIdxIterVarAcrossMultipleFunctions and + TestKeepDuplicateThreadIdxInSameFunction. The re-use within a + function should be maintained, while re-use across functions is + de-duplicated. + """ + + def before(self): + threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") + iter_var = tvm.tir.IterVar( + tvm.ir.Range(0, 256), threadIdx_x, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" + ) + + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + with T.attr(iter_var, "thread_extent", 256): + A[threadIdx_x] = A[threadIdx_x] + 1.0 + with T.attr(iter_var, "thread_extent", 256): + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + with T.attr(iter_var, "thread_extent", 256): + A[threadIdx_x] = A[threadIdx_x] + 1.0 + with T.attr(iter_var, "thread_extent", 256): + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def kernel_1(A: T.Buffer([256], "float32")): + threadIdx_x = T.env_thread("threadIdx.x") + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 1.0 + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + @T.prim_func + def kernel_2(A: T.Buffer([256], "float32")): + threadIdx_x = T.env_thread("threadIdx.x") + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 1.0 + with T.launch_thread(threadIdx_x, 256): + A[threadIdx_x] = A[threadIdx_x] + 2.0 + + return mod + + if __name__ == "__main__": tvm.testing.main() From c4e21d55268e8ed4223d2c83e7a450d7a007c8bc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Dec 2023 12:15:26 -0600 Subject: [PATCH 06/10] lint fix --- src/tir/ir/tir_visitor_with_path.cc | 1 - src/tir/ir/tir_visitor_with_path.h | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 8918f4c716e6..3ca126304a71 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index aa2adc037bfe..dd0da1fe77a9 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -29,6 +29,7 @@ #include #include +#include namespace tvm { namespace tir { From 6f1149885292dcb934f7b2b0b50a03a46179b475 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 29 Dec 2023 10:06:28 -0600 Subject: [PATCH 07/10] Updated docstrings for VerifyWellFormed --- include/tvm/tir/analysis.h | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 701e2a5143e8..c4ae5d573be9 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -307,13 +307,40 @@ TVM_DLL Map> DetectBufferAccessLCA(const PrimFunc& func); /*! * \brief Verify if the given TIR is well-formed. The verification includes: - * - Check if expressions not contain vars that is defined outside the block. + * + * - All variables are defined prior to their point of use. + * + * - No variables are used outside of the scope of their definition. + * + * - Each variable has a single point of definition. + * + * - Expressions within a tir::Block may not reference variables + * defined outside the block. For example, for a block with iter + * vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement + * `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop + * variables `i` and `j` instead of the block variables `vi` and + * `vj`. + * * \param func The PrimFunc to be verified. * \param assert_mode The indicator if it raises an error when the function is not well-formed. * \return Whether it is a well-formed TIR function. */ TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true); +/*! + * \brief Verify if the TIR in the given IRMOdule is well-formed. + * + * In addition to the checks performed for each PrimFunc (see above), + * the following checks are performed: + * + * - The same TIR variable may not be defined in more than one function + * + * \param mod The IRModule to be verified. + * \param assert_mode The indicator if it raises an error when the function is not well-formed. + * \return Whether it is a well-formed TIR module. + */ +TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true); + /*! * \brief Find the entry function of the given IRModule, i.e, functions marked by * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. From 7c61653a5b44aae865d250df9671da9b36aaf172 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 29 Dec 2023 10:06:52 -0600 Subject: [PATCH 08/10] Rely on script.Complete for read/writes Avoids issue in cortexm unit tests resulting from read/write annotations being present in the root block, followed by application of BindParams. --- src/te/operation/create_primfunc.cc | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index dc0b1fbfb86f..d862d9f67604 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -424,15 +424,12 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf } } - // Step 3. Collect Access Region - Array reads, writes; - for (const te::Tensor& tensor : extern_op->inputs) { - // We have ICHECK before so it is not needed here. - reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor])); - } - for (const Buffer& buffer : extern_op->output_placeholders) { - writes.push_back(BufferRegion::FullRegion(buffer)); - } + // The access region does not need to be collected here, as it will + // be generated with the later application of "script.Complete" in + // GenerateAndCompletePrimFunc. Waiting until later also handles + // the case where there is only a single BlockNode, which then + // becomes the root Block of the function, and should not have + // reads/writes filled in. BufferSubstituter substituter(var_map, input_buffer_map); Stmt body = substituter(extern_op->body); @@ -442,8 +439,8 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf /*predicate=*/Bool(true), /*block=*/ Block(/*iter_vars=*/{}, - /*reads=*/std::move(reads), - /*writes=*/std::move(writes), + /*reads=*/{}, + /*writes=*/{}, /*name_hint=*/info->FreshName(extern_op->name), /*body=*/std::move(body), /*init=*/NullOpt, From 2ae62d83143454b147092aca5f94b5ccb80a7dee Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 29 Dec 2023 10:07:42 -0600 Subject: [PATCH 09/10] Typo fix --- src/tir/ir/tir_visitor_with_path.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 3ca126304a71..50c8b8f5254e 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -166,7 +166,7 @@ void TIRVisitorWithPath::Visit(const Buffer& buffer, ObjectPath path) { } void TIRVisitorWithPath::Visit(const BufferRegion& region, ObjectPath path) { - Visit(region->buffer, path->Attr("path")); + Visit(region->buffer, path->Attr("buffer")); Visit(region->region, path->Attr("region")); } From 954cafffbe433d8812afed30e964c2801c9bebfc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 2 Jan 2024 10:56:23 -0600 Subject: [PATCH 10/10] Added structural equal comparison in unit test --- .../test_tir_transform_split_host_device.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tir-transform/test_tir_transform_split_host_device.py index 790695f0555f..6adfbeb81d54 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py @@ -308,8 +308,41 @@ def default_function(var_A: T.handle, var_B: T.handle, seq_len: T.int32): if blockIdx_x * 128 + threadIdx_x < seq_len: B[blockIdx_x * 128 + threadIdx_x] = A[blockIdx_x * 128 + threadIdx_x] + @I.ir_module + class expected: + @T.prim_func + def default_function(var_A: T.handle, var_B: T.handle, seq_len: T.int32): + T.func_attr({"target": T.target("cuda")}) + A = T.match_buffer(var_A, (seq_len,), "int32") + B = T.match_buffer(var_B, (seq_len,), "int32") + num_blocks: T.int32 = (seq_len + 127) // 128 + expected.default_function_kernel(A.data, B.data, num_blocks, seq_len) + + @T.prim_func(private=True) + def default_function_kernel( + A_data: T.handle("int32"), + B_data: T.handle("int32"), + num_blocks: T.int32, + seq_len: T.int32, + ): + T.func_attr( + { + "target": T.target("cuda"), + "tir.is_global_func": True, + "tir.noalias": True, + } + ) + A = T.decl_buffer(seq_len, "int32", data=A_data) + B = T.decl_buffer(seq_len, "int32", data=B_data) + blockIdx_x = T.launch_thread("blockIdx.x", num_blocks) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + if blockIdx_x * 128 + threadIdx_x < seq_len: + B[blockIdx_x * 128 + threadIdx_x] = A[blockIdx_x * 128 + threadIdx_x] + after = tvm.tir.transform.SplitHostDevice()(before) + tvm.tir.analysis.verify_well_formed(after) + tvm.ir.assert_structural_equal(expected, after) if __name__ == "__main__":