Dynamic ONNX Importer (apache#6351)
* Change onnx importer to use dynamic upsampling3d (neo-ai#3)

fix pylint

* Refactor ONNX frontend to be dynamic

Make OneHot dynamic

Support BatchMatMul with dynamically shaped inputs

fix dynamic broadcast

Add null checks to broadcast_to rel functions

fail more isolated broadcast_to test

use StructuralEqual instead of pointer comparisons in dynamic_to_static pass

add an optional weight freeze argument to onnx importer

convert onnx resize to dynamic op

add dynamic expand to onnx importer

add a shape_func for power

fix BERTSquad, lint

handle onnx graph initializer parameters more intelligently

* Dynamic ONNX importer: Upsampling and Pad (neo-ai#2)

fix lint

fix Call reference

fix a type issue with expand

fix a bad test refactor

respond to review comments, fix batch matmul tests

* black format

* fix batch matmul test

* add dynamic strided slice to the onnx importer

* fix clip importer

* fix qnn tutorial

* fix bad merge, respond to review comments

* add a simple dynamic model test

* Add dynamic-shaped autopadding to convolution and pooling ops

* fix dynamic issues in a few ops

* fix pylint

* disable tests onnxrt doesn't support

* fix pytorch test

* respond to review comments

* add documentation about partially supporting dynamic shapes

Co-authored-by: Lily Orth-Smith <[email protected]>
2 people authored and Tushar Dey committed Oct 15, 2020

1 parent fa3836c commit bbdb1f0
Showing 21 changed files with 957 additions and 489 deletions.
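The commit message above mentions a frontend that now preserves dynamic shapes and an optional weight-freeze argument. Below is a minimal sketch of how the reworked importer might be driven; the model path, input name, and the freeze_params keyword are illustrative assumptions taken from the commit message, not verbatim from this diff.

import onnx
import tvm
from tvm import relay

onnx_model = onnx.load("model.onnx")          # hypothetical model file
shape_dict = {"data": (1, 3, 224, 224)}       # hypothetical input name and shape

# The assumed freeze_params flag folds ONNX graph initializers into Relay constants,
# which lets more dynamic ops be rewritten to static ones later.
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)

with tvm.transform.PassContext(opt_level=3):
    # per the build_module.cc change in this commit, DynamicToStatic runs inside build
    lib = relay.build(mod, target="llvm", params=params)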
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
@@ -208,6 +208,17 @@ TVM_DLL Pass SimplifyInference();
*/
TVM_DLL Pass FastMath();

/*!
* \brief Find Dynamic ops and make them static
*
* Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
* them with static ops and re-performs type inference and constant folding. The pass repeats
* itself until the graph stops changing or we run too many iterations.
*
* \return The pass.
*/
TVM_DLL Pass DynamicToStatic();

/*!
* \brief Infer the type of an expression.
*
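The DynamicToStatic pass declared above is also reachable from the Python pass infrastructure; it is registered near the end of this diff as relay._transform.DynamicToStatic. A minimal sketch of invoking it by hand on an imported module, assuming the usual relay.transform Python wrapper:

import tvm
from tvm import relay

# `mod` is an existing tvm.IRModule, e.g. the output of relay.frontend.from_onnx.
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.DynamicToStatic(),  # rewrites dynamic ops whose inputs turn out to be constant
        relay.transform.FoldConstant(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)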
11 changes: 8 additions & 3 deletions include/tvm/topi/broadcast.h
@@ -54,14 +54,19 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
<< "\nvs\ninput: " << t;
auto bh = detail::BroadcastShape(output_shape, t->shape);
CHECK_EQ(output_shape.size(), bh.common_shape.size());
Array<PrimExpr> oshape;
for (size_t i = 0; i < output_shape.size(); ++i) {
CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
oshape.push_back(output_shape[i]);
} else {
CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
oshape.push_back(bh.common_shape[i]);
}
}
auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
};
return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
l, name, tag);
return tvm::te::compute(oshape, l, name, tag);
}

#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
607 changes: 315 additions & 292 deletions python/tvm/relay/frontend/onnx.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
@@ -241,6 +241,7 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("multiply", False, broadcast_shape_func)
register_shape_func("divide", False, broadcast_shape_func)
register_shape_func("floor_divide", False, broadcast_shape_func)
register_shape_func("power", False, broadcast_shape_func)
register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
51 changes: 32 additions & 19 deletions python/tvm/relay/op/nn/_nn.py
@@ -722,29 +722,18 @@ def compute_space_to_depth(attrs, inputs, out_dtype):


@script
def _conv2d_shape_func(dshape, kshape, strides, padding, dilation):
def _conv_shape_func(dshape, kshape, strides, padding, dilation):
out = output_tensor((dshape.shape[0],), "int64")
height = dshape[2]
width = dshape[3]
kheight = kshape[2]
kwidth = kshape[3]
dilated_kh = (kheight - 1) * dilation[0] + 1
dilated_kw = (kwidth - 1) * dilation[1] + 1

oc = kshape[0]

out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1

out[0] = dshape[0]
out[1] = oc
out[2] = out_height
out[3] = out_width
out[1] = kshape[0]

for i in const_range(dshape.shape[0] - 2):
dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1
out[i + 2] = (dshape[i + 2] + 2 * padding[i] - dilated_k) // strides[i] + 1
return out


@reg.register_shape_func("nn.conv2d", False)
def conv2d_shape_func(attrs, inputs, _):
def conv_shape_func(attrs, inputs, _):
"""
Shape function for conv1d, conv2d, and conv3d ops.
"""
@@ -753,7 +742,7 @@ def conv2d_shape_func(attrs, inputs, _):
dilation = get_const_tuple(attrs.dilation)

return [
_conv2d_shape_func(
_conv_shape_func(
inputs[0],
inputs[1],
convert(strides),
@@ -763,6 +752,11 @@ def conv2d_shape_func(attrs, inputs, _):
]


reg.register_shape_func("nn.conv1d", False, conv_shape_func)
reg.register_shape_func("nn.conv2d", False, conv_shape_func)
reg.register_shape_func("nn.conv3d", False, conv_shape_func)


@script
def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
out = output_tensor((dshape.shape[0],), "int64")
@@ -968,6 +962,25 @@ def dense_shape_func(attrs, inputs, _):
return ret


@script
def _batch_matmul_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]

return out


@reg.register_shape_func("nn.batch_matmul", False)
def batch_matmul_shape_func(attrs, inputs, _):
"""
Shape function for batch_matmul op.
"""
ret = [_batch_matmul_shape_func(inputs[0], inputs[1])]
return ret


@script
def _pad_shape_func(data_shape, pad_width):
out = output_tensor((data_shape.shape[0],), "int64")
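The refactored shape function above applies the same formula to every spatial dimension, which is what lets a single hybrid-script function cover conv1d, conv2d, and conv3d. A plain-Python restatement of that formula, for illustration only (not the hybrid-script code itself):

def conv_out_shape(dshape, kshape, strides, padding, dilation):
    # dshape = (N, C, *spatial), kshape = (O, I, *kernel); NCHW-style layouts assumed
    out = [dshape[0], kshape[0]]
    for i, (d, k) in enumerate(zip(dshape[2:], kshape[2:])):
        dilated_k = (k - 1) * dilation[i] + 1
        out.append((d + 2 * padding[i] - dilated_k) // strides[i] + 1)
    return out

# A 3x3 conv with stride 1 and padding 1 preserves the spatial size:
assert conv_out_shape((1, 3, 224, 224), (64, 3, 3, 3), (1, 1), (1, 1), (1, 1)) == [1, 64, 224, 224]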
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -683,7 +683,7 @@ def wrap_compute_batch_matmul(topi_compute):
"""wrap batch_matmul topi compute"""

def _compute_batch_matmul(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1])]
return [topi_compute(inputs[0], inputs[1], out_type.shape)]

return _compute_batch_matmul

21 changes: 15 additions & 6 deletions python/tvm/relay/op/strategy/x86.py
@@ -21,6 +21,7 @@
import re
from tvm import topi
from tvm.te import SpecializedCondition
from tvm.relay.ty import is_dynamic
from .generic import *
from .. import op as _op

@@ -355,12 +356,20 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
"""batch_matmul x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul),
wrap_topi_schedule(topi.x86.schedule_batch_matmul),
name="batch_matmul.x86",
plevel=10,
)
if is_dynamic(out_type):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
name="batch_matmul.generic",
plevel=10,
)
else:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul),
wrap_topi_schedule(topi.x86.schedule_batch_matmul),
name="batch_matmul.x86",
plevel=10,
)
if "cblas" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas),
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/batch_matmul.py
@@ -26,7 +26,7 @@


@autotvm.register_topi_compute("batch_matmul.cuda")
def batch_matmul(cfg, x, y):
def batch_matmul(cfg, x, y, out_shape=None):
"""Compute conv2d with NCHW layout"""
return nn.batch_matmul(x, y)

25 changes: 15 additions & 10 deletions python/tvm/topi/nn/batch_matmul.py
@@ -20,7 +20,7 @@
from ..util import get_const_tuple


def batch_matmul(x, y):
def batch_matmul(x, y, oshape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
@@ -37,14 +37,19 @@ def batch_matmul(x, y):
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
x_shape = get_const_tuple(x.shape)
y_shape = get_const_tuple(y.shape)
assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
batch, M, K = x.shape
N = y.shape[1]
k = te.reduce_axis((0, K), name="k")
if oshape is None:
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
x_shape = get_const_tuple(x.shape)
y_shape = get_const_tuple(y.shape)
assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
batch, M, K = x.shape
N = y.shape[1]
k = te.reduce_axis((0, K), name="k")
oshape = (batch, M, N)
else:
_, _, K = x.shape
k = te.reduce_axis((0, K), name="k")
return te.compute(
(batch, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
)
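With the new optional oshape argument, callers that already know the (possibly symbolic) output shape can skip the constant-shape asserts. A small TE-level sketch, assuming the signature shown in this diff:

import tvm
from tvm import te, topi

# Static case: the output shape is inferred from the inputs.
x = te.placeholder((4, 16, 32), name="x")
y = te.placeholder((4, 20, 32), name="y")
z = topi.nn.batch_matmul(x, y)                        # shape (4, 16, 20)

# Dynamic case: the batch dim is symbolic, so the caller supplies oshape explicitly.
b = te.var("b")
xd = te.placeholder((b, 16, 32), name="xd")
yd = te.placeholder((b, 20, 32), name="yd")
zd = topi.nn.batch_matmul(xd, yd, oshape=(b, 16, 20))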
6 changes: 5 additions & 1 deletion python/tvm/topi/x86/batch_matmul.py
@@ -25,7 +25,7 @@


@autotvm.register_topi_compute("batch_matmul.x86")
def batch_matmul(cfg, x, y):
def batch_matmul(cfg, x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
@@ -49,6 +49,10 @@ def batch_matmul(cfg, x, y):
assert XK == YK, "shapes of x and y is inconsistant"
B = XB
K = XK
if out_shape is not None:
assert out_shape[0] == B, "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
if cfg.is_fallback:
_default_batch_matmul_config(cfg, M, N, K)

3 changes: 3 additions & 0 deletions src/relay/backend/build_module.cc
@@ -263,6 +263,9 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::Legalize());
}

// Convert Dynamic ops to static versions
pass_seqs.push_back(transform::DynamicToStatic());

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
18 changes: 15 additions & 3 deletions src/relay/op/dyn/tensor/transform.cc
@@ -58,6 +58,11 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

Array<IndexExpr> oshape;
const auto* newshape = types[1].as<TensorTypeNode>();
if (newshape == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "reshape: expect input type to be TensorType but get " << types[1];
return false;
}

// Doesn't support dynamic output rank
for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
@@ -209,10 +214,17 @@ bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
// types = [data_type, broadcast_shape_type, ret_type]
CHECK_EQ(types.size(), 3);

const auto* target_shape = types[1].as<TensorTypeNode>();
DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
const auto* input_type = types[0].as<TensorTypeNode>();
const auto* target_type = types[1].as<TensorTypeNode>();
if (target_type == nullptr) {
return false;
}
if (input_type == nullptr) {
return false;
}
auto out_dtype = input_type->dtype;
// rank must be static
const IntImmNode* rank = target_shape->shape[0].as<IntImmNode>();
const IntImmNode* rank = target_type->shape[0].as<IntImmNode>();
CHECK(rank) << "Target shape must have static rank"; // rank must be static even in dyn pass
// could add support for dyn rank in futures

77 changes: 58 additions & 19 deletions src/relay/op/nn/convolution.h
@@ -100,7 +100,9 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Conv1D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
}
channels = wshape[0];
dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0];
}
@@ -211,7 +213,9 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
}
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -322,7 +326,9 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Conv3D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1]));
if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1]));
}
channels = wshape[0];
dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -800,16 +806,22 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< "Conv1D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
}
channels = wshape[1];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
}
// dilation
IndexExpr pad_w;
GetPaddingWidth(param->padding, &pad_w);
Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w +
param->output_padding[0]));
if (!dshape_ncw[2].as<tir::AnyNode>()) {
oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w +
param->output_padding[0]));
} else {
oshape.Set(2, dshape_ncw[2]);
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
@@ -890,7 +902,9 @@ bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< "Conv3D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0]));
if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0]));
}
channels = wshape[1];
dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -901,12 +915,25 @@ bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d +
param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h +
param->output_padding[1]));
oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w +
param->output_padding[2]));

if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d +
param->output_padding[0]));
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h +
param->output_padding[1]));
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w +
param->output_padding[2]));
} else {
oshape.Set(4, dshape_ncdhw[4]);
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
@@ -985,7 +1012,9 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
}
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -994,10 +1023,18 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h +
param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w +
param->output_padding[1]));
if (!dshape_nchw[2].as<tir::AnyNode>()) {
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h +
param->output_padding[0]));
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w +
param->output_padding[1]));
} else {
oshape.Set(3, dshape_nchw[3]);
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
@@ -1053,7 +1090,9 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
if (!data->shape[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
}
channels = wshape[0];
ksize_y = wshape[2];
ksize_x = wshape[3];
53 changes: 38 additions & 15 deletions src/relay/op/nn/nn.cc
@@ -851,15 +851,26 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
CHECK(x->shape.size() == 3 && y->shape.size() == 3);
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;

Array<tvm::PrimExpr> oshape = x->shape;
oshape.Set(2, y->shape[1]);
bool is_dyn = false;
Array<tvm::PrimExpr> oshape;
for (size_t i = 0; i < 3; ++i) {
if (x->shape[i].as<tir::AnyNode>() != nullptr || y->shape[i].as<tir::AnyNode>() != nullptr) {
is_dyn = true;
oshape.push_back(Any());
} else {
oshape.push_back(x->shape[i]);
}
}
if (!is_dyn) {
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;

oshape.Set(2, y->shape[1]);
}

// assign output type
reporter->Assign(types[2], TensorType(oshape, x->dtype));
@@ -1021,9 +1032,15 @@ bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
<< " But got " << in_layout;

auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
oshape.Set(2, oshape[2] * block_size);
oshape.Set(3, oshape[3] * block_size);
if (!oshape[1].as<tir::AnyNode>()) {
oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
}
if (!oshape[2].as<tir::AnyNode>()) {
oshape.Set(2, oshape[2] * block_size);
}
if (!oshape[3].as<tir::AnyNode>()) {
oshape.Set(3, oshape[3] * block_size);
}

// Assign output type
reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
@@ -1078,9 +1095,15 @@ bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attr
<< " But got " << in_layout;

auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(1, oshape[1] * (block_size * block_size));
oshape.Set(2, indexdiv(oshape[2], block_size));
oshape.Set(3, indexdiv(oshape[3], block_size));
if (!oshape[1].as<tir::AnyNode>()) {
oshape.Set(1, oshape[1] * (block_size * block_size));
}
if (!oshape[2].as<tir::AnyNode>()) {
oshape.Set(2, indexdiv(oshape[2], block_size));
}
if (!oshape[3].as<tir::AnyNode>()) {
oshape.Set(3, indexdiv(oshape[3], block_size));
}

// Assign output type
reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
8 changes: 5 additions & 3 deletions src/relay/op/nn/nn.h
@@ -63,9 +63,11 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (weight == nullptr) return false;
Array<tvm::PrimExpr> wshape = weight->shape;
CHECK(static_cast<int>(weight->shape.size()) == 2);
CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
<< "DenseRel: input dimension doesn't match,"
<< " data shape=" << data->shape << ", weight shape=" << weight->shape;
if (!data->shape.back().as<tir::AnyNode>()) {
CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
<< "DenseRel: input dimension doesn't match,"
<< " data shape=" << data->shape << ", weight shape=" << weight->shape;
}
oshape.Set((oshape.size() - 1), wshape[0]);
}

10 changes: 6 additions & 4 deletions src/relay/op/tensor/transform.cc
@@ -1822,9 +1822,9 @@ bool SqueezeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (p.second) {
result_shape.push_back(p.first);
} else {
const int64_t* axis_ptr = tir::as_const_int(p.first);
CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
if (const int64_t* axis_ptr = tir::as_const_int(p.first)) {
CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
}
}
}
}
@@ -2028,7 +2028,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);
if (param == nullptr) {
return false;
}
const auto* data = types[0].as<TensorTypeNode>();

if (data == nullptr) {
9 changes: 6 additions & 3 deletions src/relay/transforms/dynamic_to_static.cc
@@ -227,6 +227,9 @@ Expr DynamicToStatic(Function f, IRModule m) {
vars.Set(kv.second, kv.first);
}
const auto gv = vars[f];
// Put a limit on the while loop
// Primarily used to prevent accidental infinite loops in development
const int loop_limit = 1000;
int i = 0;
do {
pre = expr;
@@ -236,13 +239,13 @@ Expr DynamicToStatic(Function f, IRModule m) {
expr = mutator.Mutate(m->functions[gv]);
m->Update(gv, Downcast<BaseFunc>(expr));
i += 1;
} while (pre != expr && i < 1000);
} while (!StructuralEqual()(pre, expr) && i < loop_limit);
return expr;
}

namespace transform {

Pass ConvertDynamicToStatic() {
Pass DynamicToStatic() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(DynamicToStatic(f, m));
@@ -251,7 +254,7 @@ Pass ConvertDynamicToStatic() {
}

TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() {
return ConvertDynamicToStatic();
return DynamicToStatic();
});

} // namespace transform
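The convergence check above now stops when consecutive iterations produce structurally equal IR rather than the same object pointer. The same comparison is available from Python as tvm.ir.structural_equal; a small sketch:

import tvm
from tvm import relay

x1 = relay.var("x", shape=(1, 4))
f1 = relay.Function([x1], relay.add(x1, x1))
x2 = relay.var("x", shape=(1, 4))
f2 = relay.Function([x2], relay.add(x2, x2))

assert tvm.ir.structural_equal(f1, f2)  # equal by structure
assert not f1.same_as(f2)               # but not the same object (pointer inequality)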
413 changes: 331 additions & 82 deletions tests/python/frontend/onnx/test_forward.py

Large diffs are not rendered by default.

82 changes: 55 additions & 27 deletions tests/python/relay/dyn/test_dynamic_op_level10.py
@@ -27,34 +27,62 @@
import random
import tvm.testing

# TODO(mbrookhart): Enable when VM supports heterogenus execution
# TODO(mbrookhart): Enable when the VM supports heterogenus execution
# @tvm.testing.uses_gpu
def test_dyn_broadcast_to():
dtype = "uint8"
rank = 3
shape_type = "int64"
dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
x_shape = (1,)
x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
z = relay.broadcast_to(x, dyn_shape)
zz = run_infer_type(z)

assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)

func = relay.Function([x, dyn_shape], z)

x = np.random.uniform(size=x_shape).astype(dtype)
dyn_shape = (1,) * rank
ref_res = np.broadcast_to(x, dyn_shape)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


# TODO(mbrookhart): Enable when VM supports heterogenus execution
def test_broadcast_to():
def verify_more_dynamic_broadcast_to(x_shape, out_shape):
rank = len(out_shape)
dtype = "float32"
shape_type = "int64"
reshape_shape = relay.Var("shape", relay.ty.TensorType((len(x_shape),), shape_type))
broadcast_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
x = relay.Var("x", relay.ty.TensorType((np.prod(x_shape),), dtype))
r = relay.reshape(x, reshape_shape)
z = relay.broadcast_to(r, broadcast_shape)

func = relay.Function([x, reshape_shape, broadcast_shape], z)

x = np.random.uniform(size=np.prod(x_shape)).astype(dtype)
ref_res = np.broadcast_to(np.reshape(x, x_shape), out_shape)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(
x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type)
)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3))

def verify_broadcast_to(x_shape, out_shape):
rank = len(out_shape)
dtype = "float32"
shape_type = "int64"
dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
z = relay.broadcast_to(x, dyn_shape)
zz = run_infer_type(z)

assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)

func = relay.Function([x, dyn_shape], z)

x = np.random.uniform(size=x_shape).astype(dtype)
ref_res = np.broadcast_to(x, out_shape)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x, np.array(out_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_broadcast_to((1,), (1, 1, 1))
verify_broadcast_to((1, 1), (4, 1, 1))
verify_broadcast_to((4, 1), (1, 4, 3))


# TODO(mbrookhart): Enable when the VM supports heterogenus execution
# @tvm.testing.uses_gpu
def test_dyn_one_hot():
def _get_oshape(indices_shape, depth, axis):
27 changes: 27 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -362,6 +362,33 @@ def test_batch_matmul():
verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))


def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
x = relay.var("x", relay.TensorType(x_shape, dtype))
y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype))
z = relay.nn.batch_matmul(x, y)

func = relay.Function([x, y], z)
x_np = np.random.uniform(size=x_shape).astype(dtype)
y_np = np.random.uniform(size=y_shape).astype(dtype)
z_np = tvm.topi.testing.batch_matmul(x_np, y_np)

for target, ctx in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
z = intrp.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
def test_dynamic_batch_matmul():
verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))


@tvm.testing.uses_gpu
def test_shape_of():
shape = (10, 5, 12)
9 changes: 9 additions & 0 deletions tutorials/frontend/from_onnx.py
@@ -103,3 +103,12 @@
canvas[:, 672:, :] = np.asarray(result)
plt.imshow(canvas.astype(np.uint8))
plt.show()

######################################################################
# Notes
# ---------------------------------------------
# By default, ONNX defines models in terms of dynamic shapes. The ONNX importer
# retains that dynamism upon import, and the compiler attempts to convert the model
# to static shapes at compile time. If this fails, there may still be dynamic
# operations in the model. Not all TVM kernels currently support dynamic shapes;
# please file an issue on discuss.tvm.ai if you hit an error with dynamic kernels.
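If static conversion fails and graph compilation rejects the remaining dynamic ops, the Relay VM (which the new dynamic tests in this commit rely on) can still execute the model. A hedged sketch that reuses the tutorial's mod, params, and input x, mirroring the tutorial's own executor call:

import tvm
from tvm import relay

with tvm.transform.PassContext(opt_level=3):
    intrp = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
tvm_output = intrp.evaluate()(tvm.nd.array(x.astype("float32")), **params).asnumpy()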
