Skip to content

Commit

Permalink
[Analysis] Expose analyses related to vars in Python (tlc-pack#265)
Browse files Browse the repository at this point in the history
Previously, analyses to gather up all variables, free variables, bound variables, all global variables, and all global variables that are called had been implemented in C++ but had not been exposed in Python or tested. This PR exposes these analyses and adds tests for them.

Two further changes:
* The analyses previously ignored variables bound in `MatchShape` nodes; these are now treated as bindings too.
* `rec_global_vars` is renamed `called_global_vars`, since the analysis itself does not check recursion.
  • Loading branch information
slyubomirsky authored and junrushao committed Jan 26, 2023
1 parent 0fecef1 commit 6449bec
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 11 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);

/*!
* \brief Get all glabal variables for recursive call from expression expr.
* \brief Get all glabal variables used in calls in expression expr.
*
* \param expr the expression.
*
* \return List of all global variables for recursive call.
* \return List of all global variables called in expr.
*/
TVM_DLL tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr);
TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);

/*!
* \brief Get all glabal variables from expression expr.
Expand Down
94 changes: 93 additions & 1 deletion python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import tvm
from tvm import tir
from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Binding
from tvm.relax.expr import DataflowBlock, GlobalVar, Var, Expr, Function, Binding
from . import _ffi_api


Expand Down Expand Up @@ -157,3 +157,95 @@ def derive_func_ret_shape(args: List[Var], body: Expr) -> Expr:
An expression that can serve as the return shape for the function
"""
return _ffi_api.derive_func_ret_shape(args, body)


def bound_vars(expr: Expr) -> List[Var]:
"""
Return all bound variables from expression expr.
Bound variables are all variables that are declared in the expr.
They only have meaning inside that expr, and can only be used in it.
Parameters
----------
expr: Expr
The expression.
Returns
-------
ret: List[Var]
List of bound vars in expr, in post-DFS order
"""
return _ffi_api.bound_vars(expr)


def free_vars(expr: Expr) -> List[Var]:
"""
Return all free variables from expression expr.
Free variables are variables that are not bound by a
VarBinding or a function parameter in the expression.
Parameters
----------
expr: Expr
The expression.
Returns
-------
ret: List[Var]
List of free vars in expr, in post-DFS order
"""
return _ffi_api.free_vars(expr)


def all_vars(expr: Expr) -> List[Var]:
"""
Return all (local) variables from expression expr.
Parameters
----------
expr: Expr
The expression.
Returns
-------
ret: List[Var]
List of vars in expr, in post-DFS order
"""
return _ffi_api.all_vars(expr)


def all_global_vars(expr: Expr) -> List[GlobalVar]:
"""
Return all global variables from expression expr.
Parameters
----------
expr: Expr
The expression.
Returns
-------
ret: List[GlobalVar]
List of global vars in expr, in post-DFS order
"""
return _ffi_api.all_global_vars(expr)


def called_global_vars(expr: Expr) -> List[GlobalVar]:
"""
Return all global vars called (potentially recursively) from expr.
Parameters
----------
expr: Expr
The expression
Returns
-------
ret: List[GlobalVar]
List of global vars that are used recursively in expr,
in post-DFS order
"""
return _ffi_api.called_global_vars(expr)
21 changes: 15 additions & 6 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ class VarVisitor : protected ExprVisitor {
return ret;
}

Array<GlobalVar> RecGlobalVars(const Expr& expr) {
Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
this->VisitExpr(expr);
Array<GlobalVar> ret;
for (const auto& v : rec_global_vars_.data) {
for (const auto& v : called_global_vars_.data) {
ret.push_back(v);
}
return ret;
Expand Down Expand Up @@ -128,7 +128,7 @@ class VarVisitor : protected ExprVisitor {
}

if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
rec_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
}
}

Expand All @@ -138,11 +138,18 @@ class VarVisitor : protected ExprVisitor {
VisitVarDef(binding->var);
}

void VisitBinding_(const MatchShapeNode* binding) final {
if (binding->var.defined()) {
MarkBounded(binding->var);
}
ExprVisitor::VisitBinding_(binding);
}

private:
InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_;
InsertionSet<GlobalVar> global_vars_;
InsertionSet<GlobalVar> rec_global_vars_;
InsertionSet<GlobalVar> called_global_vars_;
};

class DimVisitor : public tir::ExprVisitor {
Expand Down Expand Up @@ -185,7 +192,9 @@ tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }

tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }

tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr) { return VarVisitor().RecGlobalVars(expr); }
tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
return VarVisitor().CalledGlobalVars(expr);
}

TVM_REGISTER_GLOBAL("relax.analysis.shape_vars").set_body_typed(ShapeVars);

Expand All @@ -197,7 +206,7 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);

TVM_REGISTER_GLOBAL("relax.analysis.rec_global_vars").set_body_typed(RecGlobalVars);
TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);

} // namespace relax
} // namespace tvm
2 changes: 1 addition & 1 deletion src/relax/transform/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class LambdaLifter : public ExprMutator {
String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
auto global = GlobalVar(lift_func_name);
Array<Var> captured_vars = FreeVars(func);
recur_vars_ = RecGlobalVars(func);
recur_vars_ = CalledGlobalVars(func);
auto all_global_vars = AllGlobalVars(func);

Array<Var> typed_captured_vars;
Expand Down
118 changes: 118 additions & 0 deletions tests/python/relax/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from __future__ import annotations
from typing import List, Set, Union
import pytest

import tvm
Expand All @@ -27,10 +28,19 @@
name_to_binding,
shape_vars,
derive_func_ret_shape,
all_vars,
free_vars,
bound_vars,
all_global_vars,
called_global_vars,
)
from tvm.script import relax as R


def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]:
return set(map(lambda v: v.name_hint, vars))


def test_dispatch_var():
m = tir.Var("m", "int32")
n = tir.Var("n", "int32")
Expand Down Expand Up @@ -291,5 +301,113 @@ def test_derive_func_ret_shape_free():
assert isinstance(shape_expr, rx.RuntimeDepShape)


@tvm.script.ir_module
class VarExample:
@R.function
def func(a: Tensor) -> Tensor:
return R.add(a, a)

@R.function
def main(x: Tensor, y: Tensor) -> Tensor:
z = R.add(x, y)
# no binding here
R.match_shape(x, (5, 5))
with R.dataflow():
q = R.add(z, z)
p = func(q)
r = R.match_shape(p, (5, 5))
s = r
R.output(s)
return s


def test_all_vars():
vars = all_vars(VarExample["func"])
assert len(vars) == 1
assert vars[0].name_hint == "a"

var_names = var_name_set(all_vars(VarExample["main"]))
assert var_names == {"x", "y", "z", "p", "q", "r", "s"}


def test_bound_vars():
vars = bound_vars(VarExample["func"])
assert len(vars) == 1
assert vars[0].name_hint == "a"

# all the vars are bound
var_names = var_name_set(bound_vars(VarExample["main"]))
assert var_names == {"x", "y", "z", "p", "q", "r", "s"}

# if we consider only the body, then the function arguments are not bound
body_names = var_name_set(bound_vars(VarExample["main"].body))
assert body_names == {"z", "p", "q", "r", "s"}

# if the argument isn't bound, then nothing is
assert len(bound_vars(VarExample["func"].body)) == 0


def test_free_vars():
# all the vars are bound
assert len(free_vars(VarExample["func"])) == 0
assert len(free_vars(VarExample["main"])) == 0

# the arguments are free if we look only at the bodies
func_free = var_name_set(free_vars(VarExample["func"].body))
main_free = var_name_set(free_vars(VarExample["main"].body))
assert len(func_free) == 1
assert len(main_free) == 2
assert "a" in func_free
assert main_free == {"x", "y"}

# function that captures vars
x = rx.Var("x", type_annotation=rx.DynTensorType(ndim=-1))
y = rx.Var("y", type_annotation=rx.DynTensorType(ndim=-1))
z = rx.Var("z", type_annotation=rx.DynTensorType(ndim=-1))
inner = rx.Function(
[z],
rx.op.add(x, rx.op.add(y, z)),
ret_type=rx.DynTensorType(ndim=-1),
ret_shape=rx.RuntimeDepShape(),
)
outer = rx.Function(
[x, y],
rx.Call(inner, [y]),
ret_type=rx.DynTensorType(ndim=-1),
ret_shape=rx.RuntimeDepShape(),
)
assert len(free_vars(outer)) == 0
assert var_name_set(free_vars(inner)) == {"x", "y"}


def test_all_global_vars():
# there is one call to "func"
global_vars = all_global_vars(VarExample["main"])
assert len(global_vars) == 1
assert global_vars[0].name_hint == "func"

gv1 = rx.GlobalVar("gv1")
gv2 = rx.GlobalVar("gv2")
gv3 = rx.GlobalVar("gv3")
call = rx.Call(gv1, [gv2, gv3])
call_var_names = var_name_set(all_global_vars(call))
assert call_var_names == {"gv1", "gv2", "gv3"}


def test_called_global_vars():
# there is one call to "func"
global_vars = called_global_vars(VarExample["main"])
assert len(global_vars) == 1
assert global_vars[0].name_hint == "func"

gv1 = rx.GlobalVar("gv1")
gv2 = rx.GlobalVar("gv2")
gv3 = rx.GlobalVar("gv3")
call = rx.Call(gv1, [gv2, gv3])
call_vars = called_global_vars(call)
assert len(call_vars) == 1
assert call_vars[0].name_hint == "gv1"


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 6449bec

Please sign in to comment.