Skip to content

Commit

Permalink
Importer fixes and additional testing for end to end resnet support (apache#31)
Browse files Browse the repository at this point in the history

Some fixes to enable end to end resnet.
  • Loading branch information
Josh Fromm committed Feb 28, 2023
1 parent da1ad37 commit e49d040
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 5 deletions.
61 changes: 56 additions & 5 deletions python/tvm/relax/frontend/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,58 @@ def _impl_v6(cls, bb, inputs, attr):
return out


class BatchNormalization(OnnxOpConverter):
    """Converts an onnx BatchNormalization node into an equivalent Relax expression."""

    @classmethod
    def _impl_v16(cls, bb, inputs, attr):
        # ONNX packs the five operands positionally: data, scale, bias,
        # running mean, running variance.
        data, scale, bias, mean, var = inputs[0], inputs[1], inputs[2], inputs[3], inputs[4]
        # Channel axis is fixed at 1 — assumes NCHW layout; TODO confirm
        # against the importer's layout handling.
        epsilon = attr.get("epsilon", 1e-05)
        return relax.op.nn.batch_norm(data, scale, bias, mean, var, axis=1, epsilon=epsilon)


class MaxPool(OnnxOpConverter):
    """Converts an onnx MaxPool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v12(cls, bb, inputs, attr):
        # Only the default (explicit-pads) mode is implemented; any other
        # auto_pad setting is rejected up front.
        if attr.get("auto_pad", b"NOTSET").decode("utf-8") != "NOTSET":
            raise NotImplementedError("Auto padding not yet supported.")
        # Pooling attributes fall back to the ONNX defaults when absent.
        return relax.op.nn.max_pool2d(
            inputs[0],
            attr.get("kernel_shape"),
            attr.get("strides", 1),
            attr.get("pads", 0),
            attr.get("dilations", [1, 1]),
            attr.get("ceil_mode", 0),
        )


class GlobalAveragePool(OnnxOpConverter):
    """Converts an onnx GlobalAveragePool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr):
        # Global average pooling is adaptive average pooling with a 1x1
        # output grid.
        data = inputs[0]
        return relax.op.nn.adaptive_avg_pool2d(data, 1)


class Flatten(OnnxOpConverter):
    """Converts an onnx Flatten node into an equivalent Relax expression."""

    @classmethod
    def _impl_v13(cls, bb, inputs, attr):
        data = inputs[0]
        axis = attr.get("axis", 1)
        # Static dims of the input; assumes every dim is a concrete
        # IntImm with a .value — TODO confirm for symbolic shapes.
        dims = [dim.value for dim in data.struct_info.shape]
        if axis == 0:
            # Per the ONNX spec, axis 0 flattens everything into the
            # second output dimension.
            new_shape = (1, -1)
        else:
            # Leading dims collapse into the first output dimension; the
            # rest is inferred via -1. Negative axes count from the back,
            # which the slice handles naturally.
            new_shape = (_np.prod(dims[:axis]).astype("int64"), -1)
        return relax.op.reshape(data, new_shape)


def _get_convert_map():
return {
"MatMul": relay.frontend.onnx.MatMul,
Expand Down Expand Up @@ -1136,10 +1188,10 @@ def _get_convert_map():
"Pad": Pad,
"Split": Split,
"Tile": Tile,
"BatchNormalization": relay.frontend.onnx.BatchNorm,
"GlobalAveragePool": relay.frontend.onnx.GlobalAveragePool,
"Flatten": relay.frontend.onnx.Flatten,
"MaxPool": relay.frontend.onnx.MaxPool,
"BatchNormalization": BatchNormalization,
"GlobalAveragePool": GlobalAveragePool,
"Flatten": Flatten,
"MaxPool": MaxPool,
"Identity": Identity,
"Resize": Resize,
"Einsum": Einsum,
Expand Down Expand Up @@ -1325,7 +1377,6 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(outputs)

print(op_name, node.name)
op = self._convert_operator(op_name, inputs, attr, self.opset)
# Create struct information for the new operator.
op = self.bb.normalize(op)
Expand Down
46 changes: 46 additions & 0 deletions tests/python/relax/frontend/test_onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,5 +1432,51 @@ def test_less_equal():
verify_compare("LessOrEqual", [32, 32])


def test_batch_norm():
    """End-to-end check of the BatchNormalization converter against ONNX."""
    batch_norm_node = helper.make_node(
        "BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"], epsilon=1e-2
    )
    data_shape = [2, 3, 4, 5]
    # scale/bias/mean/var are all per-channel vectors of length 3.
    per_channel = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, [3])
        for name in ("s", "bias", "mean", "var")
    ]
    graph = helper.make_graph(
        [batch_norm_node],
        "batch_norm_test",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, data_shape)]
        + per_channel,
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, data_shape)],
    )

    model = helper.make_model(graph, producer_name="batch_norm_test")
    check_correctness(model)


def test_max_pool():
    """End-to-end check of the MaxPool converter.

    A 2x2 window with stride 2 halves each spatial dimension, so the
    declared output shape [1, 3, 16, 16] matches what the operator
    actually produces from a 32x32 input. (Without explicit strides the
    ONNX default is stride 1, which yields 31x31 and contradicts the
    graph's output annotation.)
    """
    max_pool_node = helper.make_node(
        "MaxPool", ["x"], ["y"], kernel_shape=[2, 2], strides=[2, 2]
    )
    graph = helper.make_graph(
        [max_pool_node],
        "max_pool_test",
        inputs=[
            helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 32, 32]),
        ],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 16, 16])],
    )

    model = helper.make_model(graph, producer_name="max_pool_test")
    check_correctness(model)


def test_global_average_pool():
    """End-to-end check of the GlobalAveragePool converter on an NCHW tensor."""
    verify_unary("GlobalAveragePool", [1, 3, 32, 32])


def test_flatten():
    """End-to-end check of the Flatten converter for leading, negative, and
    interior axis values."""
    for flatten_axis in (0, -1, 2):
        verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": flatten_axis})


# Run the full test suite in this file when executed directly.
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit e49d040

Please sign in to comment.