Skip to content

Commit

Permalink
[Relay][Frontend] CoreML Support
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Jan 21, 2019
1 parent 0806b69 commit 1f55a8b
Show file tree
Hide file tree
Showing 11 changed files with 836 additions and 17 deletions.
9 changes: 5 additions & 4 deletions nnvm/python/nnvm/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@ def ConvolutionLayerParams(op, insym, symtab):
else:
pos = [insym, weights]

if op.isDeconvolution:
ret = _sym.conv2d_transpose(*pos, **params)
else:
ret = _sym.conv2d(*pos, **params)
# consume padding layer
if symtab.in_padding:
params['padding'] = [sum(x) for x in zip(params.get('padding', [0, 0]), symtab.paddings)]
symtab.clear_padding()

if op.isDeconvolution:
ret = _sym.conv2d_transpose(*pos, **params)
else:
ret = _sym.conv2d(*pos, **params)
return ret

def BatchnormLayerParams(op, insym, symtab):
Expand Down
4 changes: 2 additions & 2 deletions nnvm/tests/python/frontend/coreml/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import urllib
from six.moves import urllib
import os
from PIL import Image
import numpy as np
Expand All @@ -7,7 +7,7 @@ def download(url, path, overwrite=False):
if os.path.exists(path) and not overwrite:
return
print('Downloading {} to {}.'.format(url, path))
urllib.URLopener().retrieve(url, path)
urllib.request.urlretrieve(url, path)

def get_mobilenet():
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
Expand Down
8 changes: 4 additions & 4 deletions nnvm/tests/python/frontend/coreml/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import model_zoo

def get_tvm_output(symbol, x, params, target, ctx,
out_shape=(1000,), input_name='image', dtype='float32'):
out_shape=(1, 1000), input_name='image', dtype='float32'):
shape_dict = {input_name : x.shape}
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(symbol, target, shape_dict, params=params)
Expand All @@ -28,7 +28,7 @@ def get_tvm_output(symbol, x, params, target, ctx,
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()

def test_model_checkonly(model_file, model_name=''):
def run_model_checkonly(model_file, model_name=''):
model = cm.models.MLModel(model_file)
sym, params = nnvm.frontend.from_coreml(model)
x = model_zoo.get_cat_image()
Expand All @@ -38,11 +38,11 @@ def test_model_checkonly(model_file, model_name=''):

def test_mobilenet_checkonly():
model_file = model_zoo.get_mobilenet()
test_model_checkonly(model_file, 'mobilenet')
run_model_checkonly(model_file, 'mobilenet')

def test_resnet50_checkonly():
model_file = model_zoo.get_resnet50()
test_model_checkonly(model_file, 'resnet50')
run_model_checkonly(model_file, 'resnet50')

def run_tvm_graph(graph_def, input_data, input_name, output_shape, output_dtype='float32'):
""" Generic function to compile on nnvm and execute on tvm """
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __init__(self,
_make.Function, params, body, ret_type, type_params, attrs)

def __call__(self, *args):
"""Invoke the gobal function.
"""Invoke the global function.
Parameters
----------
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .mxnet import from_mxnet
from .keras import from_keras
from .onnx import from_onnx
from .coreml import from_coreml
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def __init__(self):
self.exprs = {}
self.params = {}
self.const_ctr = 1
self.in_padding = False

def new_const(self, value, shape=None, dtype="float32"):
name = "_param_%d" % (self.const_ctr)
Expand All @@ -257,6 +258,13 @@ def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr

def set_padding(self, paddings):
self.paddings = paddings
self.in_padding = True

def clear_padding(self):
self.in_padding = False


class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
Expand Down
Loading

0 comments on commit 1f55a8b

Please sign in to comment.