Skip to content

Commit

Permalink
Fix MXPredReshape in the c_predict_api (apache#11493)
Browse files Browse the repository at this point in the history
* Fix MXPredReshape in the c_predict_api.

* Add unittest for the C predict API.

* Fix path in the test.

* Fix for Windows.

* Try again to fix for Windows.

* One more try to fix test on Windows.

* Try again with CI.

* Try importing from mxnet first if cannot find the amalgamation lib.

* Add a log message when libmxnet_predict.so is not found.

* Set specific rtol and atol values.

* Fix missing rtol and atol values.

* Empty commit.

* Try again with CI.

* One more try with CI.

* Retry CI.
  • Loading branch information
hqucms authored and leezu committed Aug 16, 2018
1 parent 85179e9 commit 5832940
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 11 deletions.
62 changes: 53 additions & 9 deletions amalgamation/python/mxnet_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import os
import sys
import ctypes
import logging
import numpy as np

__all__ = ["Predictor", "load_ndarray_file"]
Expand All @@ -51,15 +52,25 @@ def c_array(ctype, values):
def _find_lib_path():
"""Find mxnet library."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/')
dll_path = [curr_path, api_path]
dll_path = [os.path.join(p, 'libmxnet.so') for p in dll_path] + \
[os.path.join(p, 'libmxnet_predict.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path)))
return lib_path
amalgamation_lib_path = os.path.join(curr_path, '../../lib/libmxnet_predict.so')
if os.path.exists(amalgamation_lib_path) and os.path.isfile(amalgamation_lib_path):
lib_path = [amalgamation_lib_path]
return lib_path
else:
logging.info('Cannot find libmxnet_predict.so. Will search for MXNet library using libinfo.py then.')
try:
from mxnet.libinfo import find_lib_path
lib_path = find_lib_path()
return lib_path
except ImportError:
libinfo_path = os.path.join(curr_path, '../../python/mxnet/libinfo.py')
if os.path.exists(libinfo_path) and os.path.isfile(libinfo_path):
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
lib_path = libinfo['find_lib_path']()
return lib_path
else:
raise RuntimeError('Cannot find libinfo.py at %s.' % libinfo_path)


def _load_lib():
Expand Down Expand Up @@ -159,6 +170,39 @@ def forward(self, **kwargs):
mx_uint(v.size)))
_check_call(_LIB.MXPredForward(self.handle))

def reshape(self, input_shapes):
"""Change the input shape of the predictor.
Parameters
----------
input_shapes : dict of str to tuple
The new shape of input data.
Examples
--------
>>> predictor.reshape({'data':data_shape_tuple})
"""
indptr = [0]
sdata = []
keys = []
for k, v in input_shapes.items():
if not isinstance(v, tuple):
raise ValueError("Expect input_shapes to be dict str->tuple")
keys.append(c_str(k))
sdata.extend(v)
indptr.append(len(sdata))

new_handle = PredictorHandle()
_check_call(_LIB.MXPredReshape(
mx_uint(len(indptr) - 1),
c_array(ctypes.c_char_p, keys),
c_array(mx_uint, indptr),
c_array(mx_uint, sdata),
self.handle,
ctypes.byref(new_handle)))
_check_call(_LIB.MXPredFree(self.handle))
self.handle = new_handle

def get_output(self, index):
"""Get the index-th output.
Expand Down
5 changes: 3 additions & 2 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
}
sym = nnvm::Symbol::CreateGroup(out_syms);
}
ret->sym = sym;

// load the parameters
std::unordered_map<std::string, NDArray> arg_params, aux_params;
Expand Down Expand Up @@ -214,6 +215,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
}

Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
ret->ctx = ctx;

std::vector<NDArray> arg_arrays, aux_arrays;
for (size_t i = 0; i < arg_shapes.size(); ++i) {
Expand All @@ -231,6 +233,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
aux_arrays.push_back(nd);
}
ret->arg_arrays = arg_arrays;
ret->aux_arrays = aux_arrays;
// bind
{
std::map<std::string, Context> ctx_map;
Expand Down Expand Up @@ -309,7 +312,6 @@ int MXPredReshape(mx_uint num_input_nodes,
<< " shape has been changed, only allow to change the shape of input data.";
}
}
p->arg_arrays.clear();

for (size_t i=0; i < aux_names.size(); ++i) {
TShape newShape = aux_shapes[i];
Expand All @@ -319,7 +321,6 @@ int MXPredReshape(mx_uint num_input_nodes,
<< " shape has been changed, only allow to change the shape of input data.";
}
ret->aux_arrays = p->aux_arrays;
p->aux_arrays.clear();

// bind
{
Expand Down
87 changes: 87 additions & 0 deletions tests/python/unittest/test_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import sys, os
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../../amalgamation/python/"))
from mxnet_predict import Predictor, load_ndarray_file

import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon
from mxnet.test_utils import assert_almost_equal
from common import setup_module, with_seed, teardown

@with_seed()
def test_predictor():
prefix = 'test_predictor_simple_dense'
symbol_file = "%s-symbol.json" % prefix
param_file = "%s-0000.params" % prefix

# two inputs with different batch sizes
input1 = np.random.uniform(size=(1,3))
input2 = np.random.uniform(size=(3,3))

# define a simple model
block = gluon.nn.HybridSequential()
block.add(gluon.nn.Dense(7))
block.add(gluon.nn.Dense(3))
block.hybridize()
block.initialize()
out1 = block.forward(nd.array(input1))
out2 = block.forward(nd.array(input2))
block.export(prefix)

# create a predictor
predictor = Predictor(open(symbol_file, "r").read(),
open(param_file, "rb").read(),
{'data':input1.shape})

# forward and get output
predictor.forward(data=input1)
predictor_out1 = predictor.get_output(0)
assert_almost_equal(out1.asnumpy(), predictor_out1, rtol=1e-5, atol=1e-6)

# reshape
predictor.reshape({'data':input2.shape})
predictor.forward(data=input2)
predictor_out2 = predictor.get_output(0)
assert_almost_equal(out2.asnumpy(), predictor_out2, rtol=1e-5, atol=1e-6)

# destroy the predictor
del predictor

@with_seed()
def test_load_ndarray():
nd_file = 'test_predictor_load_ndarray.params'
a = nd.random.uniform(shape=(7, 3))
b = nd.random.uniform(shape=(7,))
nd_data = {'a':a, 'b':b}
nd.save(nd_file, nd_data)

# test load_ndarray_file
nd_load = load_ndarray_file(open(nd_file, "rb").read())
assert(set(nd_data.keys()) == set(nd_load.keys()))
for k in nd_data.keys():
assert_almost_equal(nd_data[k].asnumpy(), nd_load[k], rtol=1e-5, atol=1e-6)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 5832940

Please sign in to comment.