Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] In SplitHostDevice, check for variables in thread extents #16250

Merged
merged 10 commits into from
Jan 3, 2024
29 changes: 28 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,40 @@ TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);

/*!
* \brief Verify if the given TIR is well-formed. The verification includes:
* - Check if expressions not contain vars that is defined outside the block.
*
* - All variables are defined prior to their point of use.
*
* - No variables are used outside of the scope of their definition.
*
* - Each variable has a single point of definition.
*
* - Expressions within a tir::Block may not reference variables
* defined outside the block. For example, for a block with iter
* vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement
* `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop
* variables `i` and `j` instead of the block variables `vi` and
* `vj`.
*
* \param func The PrimFunc to be verified.
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
* \return Whether it is a well-formed TIR function.
*/
TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);

/*!
* \brief Verify if the TIR in the given IRMOdule is well-formed.
*
* In addition to the checks performed for each PrimFunc (see above),
* the following checks are performed:
*
* - The same TIR variable may not be defined in more than one function
*
* \param mod The IRModule to be verified.
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
* \return Whether it is a well-formed TIR module.
*/
TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true);

/*!
* \brief Find the entry function of the given IRModule, i.e, functions marked by
* `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
Expand Down
19 changes: 8 additions & 11 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,12 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
}
}

// Step 3. Collect Access Region
Array<BufferRegion> reads, writes;
for (const te::Tensor& tensor : extern_op->inputs) {
// We have ICHECK before so it is not needed here.
reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor]));
}
for (const Buffer& buffer : extern_op->output_placeholders) {
writes.push_back(BufferRegion::FullRegion(buffer));
}
// The access region does not need to be collected here, as it will
// be generated with the later application of "script.Complete" in
// GenerateAndCompletePrimFunc. Waiting until later also handles
// the case where there is only a single BlockNode, which then
// becomes the root Block of the function, and should not have
// reads/writes filled in.

BufferSubstituter substituter(var_map, input_buffer_map);
Stmt body = substituter(extern_op->body);
Expand All @@ -442,8 +439,8 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
/*predicate=*/Bool(true),
/*block=*/
Block(/*iter_vars=*/{},
/*reads=*/std::move(reads),
/*writes=*/std::move(writes),
/*reads=*/{},
/*writes=*/{},
/*name_hint=*/info->FreshName(extern_op->name),
/*body=*/std::move(body),
/*init=*/NullOpt,
Expand Down
214 changes: 214 additions & 0 deletions src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,97 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <exception>
#include <optional>
#include <tuple>
#include <variant>

#include "../ir/functor_common.h"
#include "../ir/tir_visitor_with_path.h"
#include "tvm/ir/module.h"

namespace tvm {
namespace tir {

namespace {

template <typename DerivedVerifier>
class Verifier : protected TIRVisitorWithPath {
public:
template <typename TirNodeRef>
static bool Verify(const TirNodeRef& node, bool assert_on_error) {
DerivedVerifier verifier(assert_on_error);
verifier(node);
return !verifier.has_error_;
}

protected:
explicit Verifier(bool assert_on_error) : assert_on_error_(assert_on_error) {}

/* \brief Helper class to handle the bool-or-assert handles
*
* Each verifier can either return a boolean, or assert on failure.
* To avoid needing to duplicate this logic at every step, the
* Verify() method can be used. Similar to `LOG(FATAL)` or
* `LOG(DEBUG)`, it returns an object that can accept streamed
* context information.
*
* If the error should be raised, then the context is collected
* identically to `LOG(FATAL)`. If a boolean is returned, or if the
* condition passes, then the streamed context is discarded.
*
* Usage:
*
* Verify(value == expected_value)
* << "ValueError: " << value
* << " was not the expected value of " << expected_value;
*/
class VerifyStream {
public:
explicit VerifyStream(bool log_fatal) {
if (log_fatal) {
log_.emplace();
}
}

VerifyStream(const VerifyStream&) = delete;
VerifyStream& operator=(const VerifyStream&) = delete;
VerifyStream(VerifyStream&& other) { std::swap(log_, other.log_); }
VerifyStream& operator=(VerifyStream&& other) {
std::swap(log_, other.log_);
return *this;
}

template <typename T>
VerifyStream& operator<<(T&& t) {
if (log_.has_value()) {
log_.value() << std::forward<T>(t);
}
return *this;
}

~VerifyStream() noexcept(false) {
if (log_.has_value()) {
LOG(FATAL) << log_->str();
}
}

std::optional<std::ostringstream> log_{std::nullopt};
};

// TODO(Lunderberg): Add the filename/linenum with
// std::source_location when C++20 is available.
VerifyStream Verify(bool condition) {
has_error_ = has_error_ || !condition;
return VerifyStream(!condition && assert_on_error_);
}

bool assert_on_error_;
bool has_error_{false};
};

} // namespace

/*! \brief Verify all Expr inside the block does not contain:
* 1. loop vars outside the current block.
* 2. block vars of parent blocks.
Expand Down Expand Up @@ -135,10 +220,135 @@ class BlockVarAccessVerifier : public StmtExprVisitor {
bool has_error_{false};
};

class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
public:
// Until templated-this arrives in C++23, the CRTP can't inject a
// constructor into the child class. Therefore, must explicitly add
// the constructor.
using Verifier::Verifier;

private:
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
redefine_allowed_within_function_.clear();
}

void EnterDef(const IterVar& iter_var, ObjectPath path) override {
Verifier::EnterDef(iter_var, path);
if (iter_var->iter_type == IterVarType::kThreadIndex) {
redefine_allowed_within_function_.insert(iter_var->var);
}
}

void EnterDef(const Var& var, ObjectPath path) override {
bool redefine_is_allowed = redefine_allowed_within_function_.count(var);
{
auto it = currently_defined_.find(var);
Verify(it == currently_defined_.end() || redefine_is_allowed)
<< "ValueError: "
<< "TIR is ill-formed, "
<< "due to multiple nested definitions of variable " << var
<< ". It was first defined at " << it->second << ", and was re-defined at " << path;
}

{
auto it = previously_defined_.find(var);
Verify(it == previously_defined_.end() || redefine_is_allowed)
<< "ValueError: "
<< "TIR is ill-formed, "
<< "due to multiple definitions of variable " << var << ". It was first defined at "
<< it->second << ", and was later re-defined at " << path;
}

currently_defined_.insert({var, path});
}

void ExitDef(const Var& var, ObjectPath path) override {
auto active_def = currently_defined_.find(var);

currently_defined_.erase(active_def);
previously_defined_.insert({var, path});
}

void VisitExpr_(const VarNode* op, ObjectPath path) override {
auto var = GetRef<Var>(op);

auto active_def = currently_defined_.find(var);
auto verify = Verify(active_def != currently_defined_.end());
verify << "ValueError: "
<< "Invalid use of undefined variable " << var << " at " << path << ".";

// Check if there was a previous definition, and append the
// location to the error message if there was. This is to aid in
// debugging, by distinguishing between a variable that is
// currently out-of-scope, and a variable that never had a
// definition in the first place.
if (auto prev_def = previously_defined_.find(var); prev_def != previously_defined_.end()) {
verify << ". While this variable was previously defined at " << prev_def->second
<< ", this definition is no longer in-scope.";
}
}

// Variables that are defined in the currently-visited scope.
std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> currently_defined_;

// Variables that were previously defined, and are now out of scope.
std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> previously_defined_;

// Special variables that are allowed to be re-defined, so long as
// that re-definition occurs within the same PrimFunc. For example
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> redefine_allowed_within_function_;
};

/* \brief Verify unique tir::Var for each environment thread
*
* Environment threads, such as CUDA's `threadIdx.x`, are defined in
* TIR using an `AttrStmt` with the key `attr::thread_extent`. A
* `PrimFunc` may contain multiple such attributes for the same
* environment thread. However, all such attributes must use the same
* `tir::Var` for a given thread.
*/
class SingleEnvThreadVerifier : public Verifier<SingleEnvThreadVerifier> {
public:
using Verifier::Verifier;

private:
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
env_thread_vars_.clear();
}

void EnterDef(const IterVar& iter_var, ObjectPath path) override {
if (iter_var->iter_type == IterVarType::kThreadIndex) {
if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) {
const auto& [prev_var, prev_path] = it->second;
Verify(prev_var.same_as(iter_var->var))
<< "ValueError: "
<< "PrimFunc uses multiple distinct TIR variables "
<< " for the environment thread \"" << iter_var->thread_tag << "\". "
<< "While multiple tir::AttrStmt may define the same environment thread, "
<< "all definitions within a single PrimFunc must share the same tir::Var. "
<< "Binding of environment thread \"" << iter_var->thread_tag
<< "\" to the TIR variable " << iter_var->var << " at " << path
<< " conflicts with the previous binding to the TIR variable " << prev_var << " at "
<< path;
} else {
env_thread_vars_.insert({iter_var->thread_tag, {iter_var->var, path}});
}
}
}

std::unordered_map<String, std::tuple<Var, ObjectPath>> env_thread_vars_;
};

bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {
return false;
}

if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false;
if (!SingleEnvThreadVerifier::Verify(func, assert_mode)) return false;

// TODO(Siyuan): add more checks here.
return true;
}
Expand All @@ -152,6 +362,10 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
}
}
}

if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false;
if (!SingleEnvThreadVerifier::Verify(mod, assert_mode)) return false;

return true;
}

Expand Down
Loading
Loading