Refactor Dynamic to Static (apache#7368)
* DynamicToStatic Refactor

* fix test

* add regression tests

* cleanup

* skip PrepareInput if the arg is already a constant

* fix an issue with type inference with global functions
Matthew Brookhart authored and alexwong committed Feb 11, 2021
1 parent 08c78f6 commit 2be6852
Showing 2 changed files with 138 additions and 61 deletions.
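In short: instead of repeatedly running FoldConstant/InferType over the whole module until the IR stops changing, the mutator now folds the shape-bearing arguments of each dynamic op to constants on demand, via the new PrepareArgs/PrepareInput helpers in the diff below. A minimal sketch of what the pass does, using the public Python API (illustrative only, not part of this commit; the shapes are made up):

import tvm
from tvm import relay

# The target shape is an expression (shape_of), so Relay builds the
# dynamic op dyn.reshape rather than a reshape with a static attribute.
x = relay.var("x", relay.TensorType((2, 4), "float32"))
y = relay.var("y", relay.TensorType((4, 2), "float32"))
z = relay.reshape(x, relay.shape_of(y))

mod = tvm.IRModule.from_expr(relay.Function([x, y], z))
mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
print(mod["main"])  # dyn.reshape has been rewritten to a static reshape to (4, 2)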
155 changes: 96 additions & 59 deletions src/relay/transforms/dynamic_to_static.cc
@@ -34,27 +34,30 @@ namespace relay {

class DynamicToStaticMutator : public MixedModeMutator {
public:
- DynamicToStaticMutator() {
+ DynamicToStaticMutator(IRModule mod, Function func) : mod_(mod), func_(func) {
op_map_ = {
{Op::Get("dyn.reshape"),
- [](const CallNode* call_node) {
- if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
ICHECK_EQ(shape->data->ndim, 1);
return MakeReshape(call_node->args[0], ToVector(shape->data));
}
return Expr(nullptr);
}},
{Op::Get("dyn.tile"),
- [](const CallNode* call_node) {
- if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* reps = args[1].as<ConstantNode>()) {
ICHECK_EQ(reps->data->ndim, 1);
return MakeTile(call_node->args[0], ToVector(reps->data));
}
return Expr(nullptr);
}},
{Op::Get("dyn.topk"),
- [](const CallNode* call_node) {
- if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* k = args[1].as<ConstantNode>()) {
const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
ICHECK(param);
return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
@@ -63,34 +66,38 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.broadcast_to"),
- [](const CallNode* call_node) {
- if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
ICHECK_EQ(shape->data->ndim, 1);
return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
}
return Expr(nullptr);
}},
{Op::Get("dyn.zeros"),
- [](const CallNode* call_node) {
- if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
ICHECK(param);
return MakeZeros(ToVector(shape->data), param->dtype);
}
return Expr(nullptr);
}},
{Op::Get("dyn.ones"),
- [](const CallNode* call_node) {
- if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
ICHECK(param);
return MakeOnes(ToVector(shape->data), param->dtype);
}
return Expr(nullptr);
}},
{Op::Get("dyn.one_hot"),
- [](const CallNode* call_node) {
- if (const ConstantNode* depth = call_node->args[3].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* depth = args[3].as<ConstantNode>()) {
const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
ICHECK(param);
return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
@@ -100,8 +107,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.image.resize"),
- [](const CallNode* call_node) {
- if (const ConstantNode* size = call_node->args[1].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* size = args[1].as<ConstantNode>()) {
const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
ICHECK(param);
auto size_int = ToVector(size->data);
@@ -115,8 +123,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.full"),
- [](const CallNode* call_node) {
- if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
ICHECK_EQ(shape->data->ndim, 1);
const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
ICHECK(param);
@@ -125,9 +134,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.nn.upsampling"),
- [](const CallNode* call_node) {
- const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
- const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ const ConstantNode* scale_h = args[1].as<ConstantNode>();
+ const ConstantNode* scale_w = args[2].as<ConstantNode>();
if (scale_h && scale_w) {
ICHECK_EQ(scale_h->data->ndim, 0);
ICHECK_EQ(scale_w->data->ndim, 0);
@@ -140,10 +150,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.nn.upsampling3d"),
- [](const CallNode* call_node) {
- const ConstantNode* scale_d = call_node->args[1].as<ConstantNode>();
- const ConstantNode* scale_h = call_node->args[2].as<ConstantNode>();
- const ConstantNode* scale_w = call_node->args[3].as<ConstantNode>();
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ const ConstantNode* scale_d = args[1].as<ConstantNode>();
+ const ConstantNode* scale_h = args[2].as<ConstantNode>();
+ const ConstantNode* scale_w = args[3].as<ConstantNode>();
if (scale_d && scale_h && scale_w) {
ICHECK_EQ(scale_d->data->ndim, 0);
ICHECK_EQ(scale_h->data->ndim, 0);
@@ -159,9 +170,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.nn.pad"),
- [](const CallNode* call_node) {
- const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
- const ConstantNode* pad_fill = call_node->args[2].as<ConstantNode>();
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ const ConstantNode* pad_width = args[1].as<ConstantNode>();
+ const ConstantNode* pad_fill = args[2].as<ConstantNode>();
if (pad_width && pad_fill) {
ICHECK_EQ(pad_fill->data->ndim, 0); // pad_val is 1d
ICHECK_EQ(pad_width->data->ndim, 2); // pad_width is 2d
@@ -174,10 +186,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.strided_slice"),
- [](const CallNode* call_node) {
- const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
- const ConstantNode* end = call_node->args[2].as<ConstantNode>();
- const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ const ConstantNode* begin = args[1].as<ConstantNode>();
+ const ConstantNode* end = args[2].as<ConstantNode>();
+ const ConstantNode* stride = args[3].as<ConstantNode>();
if (begin && end && stride) {
ICHECK_EQ(begin->data->ndim, 1);
ICHECK_EQ(end->data->ndim, 1);
@@ -190,8 +203,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
{Op::Get("dyn.sparse_to_dense"),
- [](const CallNode* call_node) {
- const ConstantNode* output_shape = call_node->args[3].as<ConstantNode>();
+ [this](const CallNode* call_node) {
+ auto args = PrepareArgs(call_node);
+ const ConstantNode* output_shape = args[3].as<ConstantNode>();
if (output_shape) {
ICHECK_EQ(output_shape->data->ndim, 1);
return MakeSparseToDense(call_node->args[0], ToVector(output_shape->data),
@@ -200,6 +214,45 @@ class DynamicToStaticMutator : public MixedModeMutator {
return Expr(nullptr);
}},
};
+ Map<BaseFunc, GlobalVar> vars;
+ for (auto kv : mod_->functions) {
+ vars.Set(kv.second, kv.first);
+ }
+ gv_ = vars[func_];
+ }
+
+ Expr PrepareInput(const Expr& expr) {
+ BaseFunc func;
+ if (auto* func_node = expr.as<BaseFuncNode>()) {
+ func = GetRef<BaseFunc>(func_node);
+ } else {
+ func =
+ relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {});
+ }
+ mod_->Update(gv_, func);
+ mod_ = transform::FoldConstant()(mod_);
+ mod_ = transform::InferType()(mod_);
+ mod_ = transform::FoldConstant()(mod_);
+ mod_ = transform::InferType()(mod_);
+ Expr out;
+ if (expr.as<FunctionNode>()) {
+ out = mod_->Lookup(gv_);
+ } else {
+ out = mod_->Lookup(gv_).as<FunctionNode>()->body;
+ }
+ return out;
+ }
+
+ std::vector<Expr> PrepareArgs(const CallNode* call_node) {
+ std::vector<Expr> args;
+ for (auto arg : call_node->args) {
+ if (arg.as<ConstantNode>()) {
+ args.emplace_back(arg);
+ } else {
+ args.emplace_back(PrepareInput(arg));
+ }
+ }
+ return args;
+ }
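Note on the two helpers added above: PrepareInput temporarily installs the argument expression as the body of the function being mutated (wrapping a bare expression in a fresh relay::Function over its free variables), runs FoldConstant and InferType twice each at module scope — the first fold can expose types that enable further folding — and reads the folded result back out; PrepareArgs applies this to every argument of a call, skipping arguments that are already constants (per the commit message). A rough Python analogue, with a hypothetical prepare_input helper that is not in the commit:

from tvm import relay

def prepare_input(mod, gv, expr):
    # Hypothetical sketch of the C++ PrepareInput: wrap the argument in a
    # function over its free variables, install it under the current
    # GlobalVar, and let module-level passes fold it to a constant.
    mod[gv] = relay.Function(relay.analysis.free_vars(expr), expr)
    for _ in range(2):
        mod = relay.transform.FoldConstant()(mod)
        mod = relay.transform.InferType()(mod)
    return mod[gv].body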

private:
@@ -222,35 +275,19 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return post;
}

std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
op_map_;
+ IRModule mod_;
+ Function func_;
+ GlobalVar gv_;
};

Expr DynamicToStatic(Function f, IRModule m) {
- Expr pre = f;
- Expr expr = f;
- auto fold_const = transform::FoldConstant();
- auto infer_type = transform::InferType();
- DynamicToStaticMutator mutator;
- Map<BaseFunc, GlobalVar> vars;
- for (auto kv : m->functions) {
- vars.Set(kv.second, kv.first);
- }
- const auto gv = vars[f];
- // Put a limit on the while loop
- // Primarily used to prevent accidental infinite lops in development
- const int loop_limit = 1000;
- int i = 0;
- do {
- pre = expr;
- // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
- m = infer_type(m);
- m = fold_const(m);
- expr = mutator.Mutate(m->functions[gv]);
- m->Update(gv, Downcast<BaseFunc>(expr));
- i += 1;
- } while (!StructuralEqual()(pre, expr) && i < loop_limit);
- return expr;
+ DynamicToStaticMutator mutator(m, f);
+ Expr expr = mutator.Mutate(f);
+ Expr out = mutator.PrepareInput(expr);
+ return out;
}

namespace transform {
44 changes: 42 additions & 2 deletions tests/python/relay/test_pass_dynamic_to_static.py
@@ -232,11 +232,11 @@ def verify_ones_zeros(shape, dtype):

func = run_infer_type(relay.Function([x], y))
func2 = run_opt_pass(
- run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()
+ run_opt_pass(func, transform.DynamicToStatic()),
+ transform.InferType(),
)

zz = func2.body
- assert isinstance(zz, relay.Constant)
assert zz.checked_type == relay.ty.TensorType(shape, dtype)

x_data = np.random.uniform(low=1, high=1, size=shape)
@@ -518,5 +518,45 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified


+ @tvm.testing.uses_gpu
+ def test_dynamic_to_static_dynamic_rank():
+ def verify_full(fill_value, fill_shape, dtype):
+ x = relay.var("x", relay.scalar_type(dtype))
+ y = relay.var("y", relay.TensorType(fill_shape, "int64"))
+ shape = relay.shape_of(y)
+ shape = relay.strided_slice(shape, [0], relay.shape_of(shape))
+ z = relay.full(x, shape, dtype)
+
+ func = relay.Function([x, y], z)
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("full")
+
+ ref_res = np.full(fill_shape, fill_value).astype(dtype)
+ y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype("int64")
+ verify_func(func2, [fill_value, y_data], ref_res)
+
+ verify_full(4, (1, 2, 3, 4), "int32")
+ verify_full(4.0, (1, 2, 8, 10), "float32")
+
+
+ @tvm.testing.uses_gpu
+ def test_dynamic_to_static_dynamic_if():
+ x = relay.var("x", relay.TensorType((2, 2), "int64"))
+ cond = relay.const(1)
+ iff = relay.If(cond, relay.reshape(x, [1, 4]), relay.reshape(x, (4, 1)))
+
+ func = relay.Function([x], iff)
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("reshape")
+ x_data = np.random.uniform(low=-1, high=1, size=(2, 2)).astype("int64")
+ verify_func(func2, [x_data], x_data.reshape(1, 4))


if __name__ == "__main__":
pytest.main([__file__])
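For context: the tests in this file rely on helpers defined in its preamble (run_infer_type, run_opt_pass, verify_func), which this diff does not show. run_opt_pass is roughly the following (paraphrased, not part of this diff):

import tvm
from tvm import relay

def run_opt_pass(expr, opt_pass):
    # Wrap the expression in a module, run the pass, and return the
    # transformed "main" (or its body, if a bare expression was given).
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body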
