Skip to content

Commit

Permalink
[Refactor] Move VarUseDefAnalysis to header file (#14185)
Browse files Browse the repository at this point in the history
# Motivation
`UndefinedVars` is a frequently used function in our codebase and its implementation relies on `VarUseDefAnalysis` class which is more general, currently we expose `UndefinedVars` in `analysis.h`, but both the definitions of `UndefinedVars` and `VarUseDefAnalysis` resides in `split_host_device.cc`.

This PR moves `VarUseDefAnalysis` class to `analysis.h` so that developers can use it in other files that requires use/def analysis than `split_host_devices.cc`. We create a `var_use_def_analysis.cc` under `src/src/analysis` for the implementations of both `UndefinedVars` and `VarUseDefAnalysis`.

# Notes
We rename `VarUseDefAnalysis` to `VarUseDefAnalyzer`.
  • Loading branch information
yzh119 authored Mar 7, 2023
1 parent 2c4af88 commit 082c443
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 168 deletions.
10 changes: 9 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TVM_DLL double EstimateTIRFlops(const IRModule& mod);

/*!
* \brief Find undefined vars in the statement.
* \param stmt The function to be checked.
* \param stmt The statement to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Expand All @@ -107,6 +107,14 @@ TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
*/
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);

/*!
* \brief Find undefined vars in the expression.
* \param expr The expression to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& defs);

/*!
* \brief Analyze the side effect
* \param expr The expression to be checked.
Expand Down
176 changes: 176 additions & 0 deletions src/tir/analysis/var_use_def_analysis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file var_use_def_analysis.cc
* \brief Classes and functions to analyze var defition and usage.
*/
#include "var_use_def_analysis.h"
namespace tvm {
namespace tir {

VarUseDefAnalyzer::VarUseDefAnalyzer(const Array<Var>& defined_vars, bool visit_thread_extent)
: visit_thread_extent_(visit_thread_extent) {
for (const Var v : defined_vars) {
use_count_[v.get()] = 0;
}
}

void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
// use the first appearance as def.
if (!use_count_.count(iv->var.get())) {
this->HandleDef(iv->var.get());
}

if (visit_thread_extent_) {
this->VisitExpr(op->value);
}

this->VisitStmt(op->body);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}

void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) {
this->HandleDef(op->var.get());
StmtExprVisitor::VisitStmt_(op);
}

void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) {
this->HandleDef(op->loop_var.get());
StmtExprVisitor::VisitStmt_(op);
}

void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) {
this->HandleDef(op->buffer_var.get());
StmtExprVisitor::VisitStmt_(op);
}

void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) {
this->HandleDef(op->buffer_var.get());
StmtExprVisitor::VisitStmt_(op);
}

void VarUseDefAnalyzer::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) {
// Weaker SSA condition
// A single var can be binded in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var.get());
this->VisitExpr(op->value);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
this->HandleDef(op->var.get());
let_binding_[op->var.get()] = op;
}
this->VisitExpr(op->body);
}

void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) {
this->HandleUse(op);
StmtExprVisitor::VisitExpr_(op);
}

void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) {
for (const auto& iv : op->axis) {
this->HandleDef(iv->var.get());
}
StmtExprVisitor::VisitExpr_(op);
}

void VarUseDefAnalyzer::VisitExpr_(const LoadNode* op) {
LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
}

void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}

void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) {
this->HandleUse(buffer->data.get());
auto visit_arr = [&](Array<PrimExpr> arr) {
for (const auto& element : arr) {
this->VisitExpr(element);
}
};

visit_arr(buffer->shape);
visit_arr(buffer->strides);
}

void VarUseDefAnalyzer::HandleDef(const VarNode* v) {
ICHECK(!def_count_.count(v)) << "variable " << v->name_hint
<< " has already been defined, the Stmt is not SSA";
ICHECK(!use_count_.count(v)) << "variable " << v->name_hint
<< " has been used before definition!";
use_count_[v] = 0;
def_count_[v] = 1;
}

void VarUseDefAnalyzer::HandleUse(const VarNode* v) {
auto it = use_count_.find(v);
if (it != use_count_.end()) {
if (it->second >= 0) {
++it->second;
}
} else {
undefined_.push_back(GetRef<Var>(v));
use_count_[v] = -1;
}
}

Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
VarUseDefAnalyzer m(args);
m(stmt);
return m.undefined_;
}

Array<Var> UndefinedVars(const PrimExpr& expr) {
VarUseDefAnalyzer m({});
m(expr);
return m.undefined_;
}

Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& args) {
VarUseDefAnalyzer m(args);
m(expr);
return m.undefined_;
}

} // namespace tir
} // namespace tvm
89 changes: 89 additions & 0 deletions src/tir/analysis/var_use_def_analysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/src/tir/analysis/var_use_def_analyzer.h
* \brief Variable definition and usage analysis class.
*/
#ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>

namespace tvm {
namespace tir {

/*!
* \brief Visitor class to perform use/def analysis, also delete unreferenced lets.
* \param defined_vars Variables that have been defined.
* \param visit_thread_extent Whether enters thread extent expressions or not.
* \sa UndefinedVars
*/
class VarUseDefAnalyzer : public StmtExprVisitor {
public:
explicit VarUseDefAnalyzer(const Array<Var>& defined_vars, bool visit_thread_extent = true);
// The fields are publically readible to
// be accessible to the users.
bool visit_thread_extent_{true};
Array<Var> undefined_;

std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;

private:
ExprDeepEqual deep_equal_;
std::unordered_map<const VarNode*, const LetNode*> let_binding_;
void VisitStmt_(const AttrStmtNode* op) final;

void VisitStmt_(const LetStmtNode* op) final;

void VisitStmt_(const ForNode* op) final;

void VisitStmt_(const AllocateNode* op) final;

void VisitStmt_(const AllocateConstNode* op) final;

void VisitStmt_(const StoreNode* op) final;

void VisitStmt_(const BufferStoreNode* op) final;

void VisitExpr_(const LetNode* op) final;

void VisitExpr_(const VarNode* op) final;

void VisitExpr_(const ReduceNode* op) final;

void VisitExpr_(const LoadNode* op) final;

void VisitExpr_(const BufferLoadNode* op) final;

void HandleDef(const VarNode* v);

void HandleUse(const VarNode* v);

void VisitBuffer(Buffer buffer);
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_
Loading

0 comments on commit 082c443

Please sign in to comment.