diff --git a/CMakeLists.txt b/CMakeLists.txt index 2bb1f225ad13..aec5f21131d1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -699,7 +699,7 @@ else() endif() -add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/lib_api/mylib.cc) +add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc) target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) set(MXNET_INSTALL_TARGETS mxnet) if(UNIX) diff --git a/Makefile b/Makefile index d60c0fcbf1da..639f259487ab 100644 --- a/Makefile +++ b/Makefile @@ -662,6 +662,10 @@ cpplint: pylint: python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet +# sample lib for MXNet extension dynamically loading custom operator +sample_lib: + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o libsample_lib.so -I include/mxnet + # Cython build cython: cd python; $(PYTHON) setup.py build_ext --inplace --with-cython @@ -705,10 +709,6 @@ rpkgtest: Rscript -e 'require(testthat);res<-test_dir("R-package/tests/testthat");if(!testthat:::all_passed(res)){stop("Test failures", call. = FALSE)}' Rscript -e 'res<-covr:::package_coverage("R-package");fileConn<-file(paste("r-package_coverage_",toString(runif(1)),".json"));writeLines(covr:::to_codecov(res), fileConn);close(fileConn)' - -sample_lib: - $(CXX) -shared -fPIC example/lib_api/mylib.cc -o libsample_lib.so -I include/mxnet - scalaclean: (cd $(ROOTDIR)/scala-package && mvn clean) @@ -760,6 +760,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN) cd $(NNVM_PATH); $(MAKE) clean; cd - cd $(TVM_PATH); $(MAKE) clean; cd - cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - + $(RM) libsample_lib.so $(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS)) $(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS)) else @@ -771,6 +772,7 @@ clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN) cd $(NNVM_PATH); $(MAKE) clean; cd - cd $(TVM_PATH); $(MAKE) clean; cd - cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - + $(RM) libsample_lib.so endif clean_all: clean diff --git a/example/lib_api/Makefile b/example/extensions/lib_api/Makefile similarity index 80% rename from example/lib_api/Makefile rename to example/extensions/lib_api/Makefile index e5893c8065c4..cb529390b77f 100644 --- a/example/lib_api/Makefile +++ b/example/extensions/lib_api/Makefile @@ -16,16 +16,16 @@ # under the License. all: - g++ -shared -fPIC mylib.cc -o mylib.so -I ../../include/mxnet + g++ -std=c++11 -shared -fPIC init_lib.cc -o libinit_lib.so -I ../../../include/mxnet test: - g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../include/mxnet + g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../../include/mxnet windows: - cl /LD mylib.cc + cl /LD init_lib.cc win_test: cl libtest.cc clean: - rm -rf mylib.so libtest + rm -rf *.so libtest diff --git a/example/lib_api/mylib.cc b/example/extensions/lib_api/init_lib.cc similarity index 91% rename from example/lib_api/mylib.cc rename to example/extensions/lib_api/init_lib.cc index e67560a87f3d..6a040ffa2ecb 100644 --- a/example/lib_api/mylib.cc +++ b/example/extensions/lib_api/init_lib.cc @@ -19,19 +19,19 @@ /*! * Copyright (c) 2015 by Contributors - * \file mylib.cc + * \file init_lib.cc * \brief Sample library file */ #include #include "lib_api.h" -int initialize(int version) { +MXReturnValue initialize(int version) { if (version >= 10400) { std::cout << "MXNet version " << version << " supported" << std::endl; - return 1; + return MX_SUCCESS; } else { std::cout << "MXNet version " << version << " not supported" << std::endl; - return 0; + return MX_FAIL; } } diff --git a/example/lib_api/libtest.cc b/example/extensions/lib_api/libtest.cc similarity index 95% rename from example/lib_api/libtest.cc rename to example/extensions/lib_api/libtest.cc index 8bdf36c05d37..0b2c6f64789c 100644 --- a/example/lib_api/libtest.cc +++ b/example/extensions/lib_api/libtest.cc @@ -40,10 +40,10 @@ int main(void) { // Get a handle to the library. #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) HINSTANCE handle; - handle = LoadLibrary(TEXT("mylib.dll")); + handle = LoadLibrary(TEXT("libinit_lib.dll")); #else void *handle; - handle = dlopen("mylib.so", RTLD_LAZY); + handle = dlopen("libinit_lib.so", RTLD_LAZY); #endif if (!handle) { diff --git a/example/lib_api/test.py b/example/extensions/lib_api/test_loading.py similarity index 87% rename from example/lib_api/test.py rename to example/extensions/lib_api/test_loading.py index d73d85c02ced..d2fc2185716c 100644 --- a/example/lib_api/test.py +++ b/example/extensions/lib_api/test_loading.py @@ -26,6 +26,8 @@ import os if (os.name=='posix'): - mx.library.load('mylib.so') + path = os.path.abspath('libinit_lib.so') + mx.library.load(path) elif (os.name=='nt'): - mx.library.load('mylib.dll') + path = os.path.abspath('libinit_lib.dll') + mx.library.load(path) diff --git a/example/extensions/lib_custom_op/Makefile b/example/extensions/lib_custom_op/Makefile new file mode 100644 index 000000000000..66079a16a338 --- /dev/null +++ b/example/extensions/lib_custom_op/Makefile @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +all: subgraph_lib gemm_lib + +gemm_lib: + g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet + +subgraph_lib: + g++ -shared -fPIC -std=c++11 subgraph_lib.cc -o libsubgraph_lib.so -I ../../../include/mxnet + +clean: + rm -rf libsubgraph_lib.so libgemm_lib.so diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc new file mode 100644 index 000000000000..3835207e2a16 --- /dev/null +++ b/example/extensions/lib_custom_op/gemm_lib.cc @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file gemm_lib.cc + * \brief Sample 2D gemm custom operator implementation library file + */ + +#include +#include "lib_api.h" + +// main matrix multiplication routine +void gemm(const float* A, const float* B, float* C, + const unsigned n, const unsigned k, const unsigned m) { + unsigned i, j, kk; + for (i = 0; i < n; i++) { + for (j = 0; j < m; j++) { + C[i*m+j] = 0; + for (kk = 0; kk < k; kk++) { + C[i*m+j] += A[i*k+kk] * B[kk*m+j]; + } + } + } +} + +void transpose(const float* A, float* At, const unsigned n, const unsigned m) { + unsigned i, j; + for (i = 0; i < n; i++) { + for (j = 0; j < m; j++) { + At[i*m+j] = A[j*n+i]; + } + } +} + +/* + * Executes C = A * B + * inputs[0] = A; inputs[1] = B; outputs[0] = C + */ +MXReturnValue forward(std::map attrs, + std::vector inputs, + std::vector outputs, + OpResource res) { + // simple example of using runtime data type + if (inputs[0].dtype == kFloat32) { + typedef float DType; + // extract data pointers from tensors + // if using dltensor repr, below lines can be changed to something like + // DType* A = reinterpret_cast(inputs[0].dltensor.data); + DType* A = inputs[0].data(); + DType* B = inputs[1].data(); + DType* C = outputs[0].data(); + // set tensor shapes + unsigned n = inputs[0].shape[0]; + unsigned k = inputs[0].shape[1]; + unsigned m = inputs[1].shape[1]; + + gemm(A, B, C, n, k, m); + } + return MX_SUCCESS; +} + +/* + * Executes dA = dC * B.T; Executes dB = A.T * dC + ***** gradient inputs + * inputs[0] = dC + ***** original inputs + * inputs[1] = A; inputs[2] = B + ***** original outputs + * inputs[3] = C + ***** gradient outputs + * outputs[0] = dA; outputs[1] = dB + */ +MXReturnValue backward(std::map attrs, + std::vector inputs, + std::vector outputs, + OpResource res) { + // extract data pointers from tensors + float* dC = inputs[0].data(); + float* A = inputs[1].data(); + float* B = inputs[2].data(); + float* dA = outputs[0].data(); + float* dB = outputs[1].data(); + // set tensor shapes + unsigned n = inputs[1].shape[0]; + unsigned k = inputs[1].shape[1]; + unsigned m = inputs[2].shape[1]; + // allocate temporary workspace memory through resource manager + // for multiple arrays better to request a big memory pool + void *workspace = res.alloc((k*n + m*k) * sizeof(float)); + float *At = static_cast(workspace); + float *Bt = static_cast(workspace) + (k*n); + + transpose(A, At, k, n); + transpose(B, Bt, m, k); + gemm(dC, Bt, dA, n, m, k); + gemm(At, dC, dB, k, n, m); + + return MX_SUCCESS; +} + +MXReturnValue parseAttrs(std::map attrs, int* num_in, int* num_out) { + *num_in = 2; + *num_out = 1; + return MX_SUCCESS; +} + +MXReturnValue inferType(std::map attrs, + std::vector &intypes, + std::vector &outtypes) { + // validate inputs + if (intypes.size() != 2) { + std::cout << "Expected 2 inputs to inferType" << std::endl; + return MX_FAIL; + } + for (unsigned i = 0; i < intypes.size(); i++) { + if (intypes[i] != kFloat32) { + std::cout << "Expected input " << i << " to have float32 type" << std::endl; + return MX_FAIL; + } + } + + outtypes[0] = intypes[0]; + return MX_SUCCESS; +} + +MXReturnValue inferShape(std::map attrs, + std::vector> &inshapes, + std::vector> &outshapes) { + // validate inputs + if (inshapes.size() != 2) { + std::cout << "Expected 2 inputs to inferShape" << std::endl; + return MX_FAIL; + } + if (inshapes[0].size() != 2 || inshapes[1].size() != 2) { + std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl; + return MX_FAIL; + } + + unsigned n = inshapes[0][0]; + unsigned k = inshapes[0][1]; + unsigned kk = inshapes[1][0]; + unsigned m = inshapes[1][1]; + if (k != kk) { + std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl; + return MX_FAIL; + } + + outshapes[0] = {n, m}; + return MX_SUCCESS; +} + +REGISTER_OP(my_gemm) +.setForward(forward) +.setBackward(backward) +.setParseAttrs(parseAttrs) +.setInferType(inferType) +.setInferShape(inferShape); + +/* ------------------------------------------------------------------------- */ + +class MyStatefulGemm : public CustomStatefulOp { + public: + explicit MyStatefulGemm(int count) : count(count) {} + + MXReturnValue Forward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + ++count; + std::cout << "Info: keyword + number of forward: " << count << std::endl; + std::map attrs; + return forward(attrs, inputs, outputs, op_res); + } + + MXReturnValue Backward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::map attrs; + return backward(attrs, inputs, outputs, op_res); + } + + ~MyStatefulGemm() {} + + private: + int count; +}; + +MXReturnValue createOpState(std::map attrs, + CustomStatefulOp** op_inst) { + int count = 0; + if (attrs.count("test_kw") > 0) + count = std::stoi(attrs["test_kw"]); + *op_inst = new MyStatefulGemm(count); + std::cout << "Info: stateful operator created" << std::endl; + return MX_SUCCESS; +} + +MXReturnValue mutateInputs(std::map attrs, + std::vector &input_indices) { + // input_indices.push_back(1); // mark mutate input + return MX_SUCCESS; +} + +REGISTER_OP(state_gemm) +.setParseAttrs(parseAttrs) +.setInferType(inferType) +.setInferShape(inferShape) +.setMutateInputs(mutateInputs) +.setCreateOpState(createOpState); + +MXReturnValue initialize(int version) { + if (version >= 10400) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } +} diff --git a/example/extensions/lib_custom_op/subgraph_lib.cc b/example/extensions/lib_custom_op/subgraph_lib.cc new file mode 100644 index 000000000000..8e7e8833745a --- /dev/null +++ b/example/extensions/lib_custom_op/subgraph_lib.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file subgraph_lib.cc + * \brief subgraph operator implementation library file + */ + +#include +#include "lib_api.h" + +MXReturnValue parseAttrs(std::map attrs, + int* num_in, int* num_out) { + *num_in = 1; + *num_out = 1; + if (attrs.count(SUBGRAPH_SYM_JSON)) { + // example of subgraph json parsing + JsonParser jp; + JsonVal val = jp.parse_to_json(attrs[SUBGRAPH_SYM_JSON]); + int input = 0; + for (auto &item : val.map[JsonVal("nodes")].list) { + if (item.map[JsonVal("op")].str == "null") + input++; + } + int output = val.map[JsonVal("heads")].list.size(); + *num_in = input; + *num_out = output; + } + return MX_SUCCESS; +} + +MXReturnValue inferType(std::map attrs, + std::vector &intypes, + std::vector &outtypes) { + outtypes[0] = intypes[0]; + return MX_SUCCESS; +} + +MXReturnValue inferShape(std::map attrs, + std::vector> &inshapes, + std::vector> &outshapes) { + outshapes[0] = inshapes[0]; + return MX_SUCCESS; +} + +class MyStatefulOp : public CustomStatefulOp { + public: + explicit MyStatefulOp(std::string sym) : subgraph_sym(sym) {} + + MXReturnValue Forward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::cout << "Info: subgraph symbol is: " << std::endl; + std::cout << subgraph_sym << std::endl; + float* in_data = inputs[0].data(); + float* out_data = outputs[0].data(); + std::cout << "Info: output is: " << std::endl; + for (int i = 0; i < inputs[0].size(); i++) { + out_data[i] = in_data[i]; + } + return MX_SUCCESS; + } + + private: + std::string subgraph_sym; +}; + +MXReturnValue createOpState(std::map attrs, + CustomStatefulOp** op_inst) { + std::string serialized_subgraph = "[empty]"; + // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field + // custom subgraph is stored as json string in custom operator attrs map entry + if (attrs.count(SUBGRAPH_SYM_JSON)) { + // user can now parse json and run other custom ops inside subgraph + serialized_subgraph = attrs[SUBGRAPH_SYM_JSON]; + } + *op_inst = new MyStatefulOp(serialized_subgraph); + std::cout << "Info: stateful operator created" << std::endl; + return MX_SUCCESS; +} + +REGISTER_OP(_custom_subgraph_op) +.setParseAttrs(parseAttrs) +.setInferType(inferType) +.setInferShape(inferShape) +.setCreateOpState(createOpState); + +MXReturnValue initialize(int version) { + if (version >= 10400) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } +} diff --git a/example/extensions/lib_custom_op/test_gemm.py b/example/extensions/lib_custom_op/test_gemm.py new file mode 100644 index 000000000000..9a588255032f --- /dev/null +++ b/example/extensions/lib_custom_op/test_gemm.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=arguments-differ + +# This test checks dynamic loading of custom library into MXNet +# and checks end to end compute of a simple 2D gemm custom op + +import mxnet as mx +import os + +#load library +if (os.name=='posix'): + path = os.path.abspath('libgemm_lib.so') + mx.library.load(path) +elif (os.name=='nt'): + path = os.path.abspath('libgemm_lib.dll') + mx.library.load(path) + +a = mx.nd.array([[1,2,3],[4,5,6]]) +b = mx.nd.array([[7],[8],[9]]) + +print("--------start ndarray compute---------") +print(mx.nd.my_gemm(a,b)) +print("--------") +print(mx.nd.state_gemm(a,b,test_kw=100)) + +print("--------start symbolic compute--------") +s = mx.sym.Variable('s') +t = mx.sym.Variable('t') +c = mx.sym.my_gemm(s,t) +d = mx.sym.state_gemm(s,t,test_kw=200) + +in_grad = [mx.nd.empty((2,3)),mx.nd.empty((3,1))] +in_grad2 = [mx.nd.empty((2,3)),mx.nd.empty((3,1))] + +exe = c.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad) +exe2 = d.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad2) + +out = exe.forward() +print(out) +print("-------") + +out2 = exe2.forward() +out2 = exe2.forward() +print(out2) +print("-------") + +# baseline forward +e = mx.sym.linalg.gemm2(s,t) +in_grad3 = [mx.nd.empty((2,3)),mx.nd.empty((3,1))] +exe3 = e.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad3) +out3 = exe3.forward() +print(out3) + +print("--------start backward compute--------") +out_grad = mx.nd.ones((2,1)) +exe.backward([out_grad]) +print(in_grad) +print("-------") +exe2.backward([out_grad]) +print(in_grad2) +print("-------") +exe3.backward([out_grad]) +print(in_grad3) diff --git a/example/extensions/lib_custom_op/test_subgraph.py b/example/extensions/lib_custom_op/test_subgraph.py new file mode 100644 index 000000000000..2625b13f6794 --- /dev/null +++ b/example/extensions/lib_custom_op/test_subgraph.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=arguments-differ + +# This test checks if dynamic loading of library into MXNet is successful +# and checks the end of end computation of custom operator + +import mxnet as mx +import os, ctypes +from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle + +# load library +if (os.name=='posix'): + path = os.path.abspath('libsubgraph_lib.so') + mx.library.load(path) +elif (os.name=='nt'): + path = os.path.abspath('libsubgraph_lib.dll') + mx.library.load(path) + +a = mx.sym.var('a') +b = mx.sym.var('b') +c = a + b +d = mx.sym.exp(c) +ret = mx.sym.log(d) + +op_names = ['exp','log'] +out = SymbolHandle() + +check_call(_LIB.MXBuildSubgraphByOpNames(ret.handle, + c_str('default'), + mx_uint(len(op_names)), + c_str_array(op_names), + ctypes.byref(out))) +partitioned_sym = mx.sym.Symbol(out) +json_sym = partitioned_sym.tojson() + +mystr = json_sym.replace("_CachedOp","_custom_subgraph_op") +mysym = mx.sym.load_json(mystr) + +exe = mysym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) +out = exe.forward() +print(out) diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index ca3b2952eafa..290a63518373 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -18,33 +18,999 @@ */ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2019 by Contributors * \file lib_api.h * \brief APIs to interact with libraries + * This API specifies function prototypes to + * register custom ops for library authors */ + #ifndef MXNET_LIB_API_H_ #define MXNET_LIB_API_H_ +#include +#include +#include +#include +#include +#include +#include +#include + +#define MX_LIBRARY_VERSION 1 + +/* + * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h + */ +#ifndef DLPACK_VERSION +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current version of dlpack */ +#define DLPACK_VERSION 020 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { + #endif + /*! + * \brief The device type in DLContext. + */ + typedef enum { + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLGPU = 2, + /*! + * \brief Pinned CUDA GPU device by cudaMallocHost + * \note kDLCPUPinned = kDLCPU | kDLGPU + */ + kDLCPUPinned = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + } DLDeviceType; + + /*! + * \brief A Device context for Tensor and operator. + */ + typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! \brief The device index */ + int device_id; + } DLContext; + + /*! + * \brief The type code options DLDataType. + */ + typedef enum { + kDLInt = 0U, + kDLUInt = 1U, + kDLFloat = 2U, + } DLDataTypeCode; + + /*! + * \brief The data type the tensor can hold. + * + * Examples + * - float: type_code = 2, bits = 32, lanes=1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 + * - int8: type_code = 0, bits = 8, lanes=1 + */ + typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; + } DLDataType; + + /*! + * \brief Plain C Tensor object, does not manage memory. + */ + typedef struct { + /*! + * \brief The opaque data pointer points to the allocated data. This will be + * CUDA device pointer or cl_mem handle in OpenCL. This pointer is always + * aligns to 256 bytes as in CUDA. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + */ + void* data; + /*! \brief The device context of the tensor */ + DLContext ctx; + /*! \brief Number of dimensions */ + int ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! \brief The shape of the tensor */ + int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes) + * can be NULL, indicating tensor is compact and row-majored. + */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; + } DLTensor; +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif + +/*! + * \brief Tensor data type, consistent with mshadow data type + */ +enum MXDType { + kFloat32 = 0, + kFloat64 = 1, + kFloat16 = 2, + kUint8 = 3, + kInt32 = 4, + kInt8 = 5, + kInt64 = 6, +}; + +enum MXReturnValue { + MX_FAIL = 0, + MX_SUCCESS = 1, +}; + +/*! + * \brief Tensor data structure used by custom operator + */ +struct MXTensor { + MXTensor() : data_ptr(NULL) {} + + MXTensor(void *data_ptr, const std::vector &shape, MXDType dtype) + : data_ptr(data_ptr), shape(shape), dtype(dtype) {} + + /*! \brief populate DLTensor fields */ + void setDLTensor() { + dltensor.data = data_ptr; + dltensor.ctx.device_type = kDLCPU; + dltensor.ctx.device_id = 0; + dltensor.ndim = shape.size(); + dltensor.shape = const_cast(shape.data()); + dltensor.strides = NULL; + dltensor.byte_offset = 0; + dltensor.dtype.lanes = 1; + switch (dtype) { + case kFloat32: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 32; + break; + case kFloat64: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 64; + break; + case kFloat16: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 16; + break; + case kUint8: + dltensor.dtype.code = kDLUInt; + dltensor.dtype.bits = 8; + break; + case kInt32: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 32; + break; + case kInt8: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 8; + break; + case kInt64: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 64; + break; + default: + dltensor.dtype.code = 0; + dltensor.dtype.bits = 0; + throw std::runtime_error("Error! Invalid dtype flag: " + + std::to_string(static_cast(dtype)) + + " when constructing MXTensor"); + } + } + + /*! \brief helper function to cast data pointer */ + template + inline data_type* data() { + return reinterpret_cast(data_ptr); + } + + /*! \brief helper function to get data size */ + inline int64_t size() { + int64_t size = 1; + for (unsigned int i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; + } + + // data is flatten 1D repr of tensor, elements are in continuous memory + // user can access each element using the shape of tensor + void *data_ptr; + + // shape is in [2,3,4] format to represent high-dim tensor + std::vector shape; + + // type can only be MXDType enum types + MXDType dtype; + + // corresponding DLTensor repr of MXTensor + // easy way to reuse functions taking DLTensor + DLTensor dltensor; +}; + +/*! + * \brief resource malloc function to allocate memory inside Forward/Backward functions + */ +typedef void* (*xpu_malloc_t)(void*, int); + /*! - * \brief Following are the APIs implemented in the external library + * \brief provide resource APIs memory allocation mechanism to Forward/Backward functions + */ +class OpResource { + public: + OpResource(xpu_malloc_t cm, void* ca) : cpu_malloc(cm), cpu_alloc(ca) {} + + /*! \brief allocate memory controlled by MXNet */ + void* alloc(int size) { + return cpu_malloc(cpu_alloc, size); + } + + private: + xpu_malloc_t cpu_malloc; + void* cpu_alloc; +}; + +/*! + * \brief Json utility to parse serialized subgraph symbol + */ +/*! \brief Macro to help passing serialized subgraph through attribute dict */ +#define SUBGRAPH_SYM_JSON "subgraph_sym_json" + +/*! \brief Types of JSON objects */ +enum JsonType {ERR, STR, NUM, LIST, MAP}; + +/*! \brief definition of JSON objects */ +struct JsonVal { + JsonVal() : type(ERR), num(-1), str("") {} // default constructor + // construct a JSON object by type + explicit JsonVal(JsonType t) : type(t), num(-1), str("") {} + // construct a string JSON object + explicit JsonVal(std::string s) : type(STR), num(-1), str(s) {} + // construct a number JSON object + explicit JsonVal(int n) : type(NUM), num(n), str(std::to_string(n)) {} + // complex constructor + JsonVal(JsonType t, int n, std::string s) : type(t), num(n), str(s) {} + bool operator<(const JsonVal &o) const { + // for string JSON objects compare the string + if (type == STR) return type == o.type && str < o.str; + // for number JSON objects compare the number + if (type == NUM) return type == o.type && num < o.num; + // for list JSON objects, compare the size of list, and then each object in the list + if (type == LIST) { + if (list.size() != o.list.size()) return false; + for (unsigned int i=0; i< list.size(); i++) + if (list[i] < o.list[i]) + return false; // if we find an object that doesnt match return + return true; // all objects in lists matched + } + // for map JSON objects, compare the size of map, and then each key/value in the maps + if (type == MAP) { + if (map.size() != o.map.size()) return false; + for (auto &item : map) { + // if one map is missing a key in another return + if (o.map.find(item.first) == o.map.end()) return false; + if (item.second < o.map.at(item.first)) return false; + } + return true; + } + return type < o.type; + } + JsonType type; + int num; + std::string str; + std::vector list; + std::map map; +}; + +/*! \brief functions used for parsing JSON */ +struct JsonParser { + JsonVal parse_to_json(std::string json) { + unsigned int idx = 0; + return parse(json, &idx); + } + void print_json_val(JsonVal val) { + std::cout << json_val_string(val) << std::endl; + } + // debug function to convert a JSON object to a string + std::string json_val_string(const JsonVal &val) { + std::string ret; + switch (val.type) { + case ERR: + ret = "json(Error)"; + break; + case STR: + ret = "json(STR:" + val.str + ")"; + break; + case NUM: + ret = "json(INT:" + val.str + ")"; + break; + case LIST: + ret = "json(LIST:["; + for (auto &item : val.list) + ret += json_val_string(item) + ","; + ret += "])"; + break; + case MAP: + ret = "json(MAP:{"; + for (auto &item : val.map) + ret += json_val_string(item.first) + " : " + json_val_string(item.second) + ","; + ret += "})"; + break; + } + return ret; + } + // parse a string JSON object + JsonVal parse_string(std::string json, unsigned int* idx) { + JsonVal ret(STR); + while (*idx < json.size()) { + if (json[*idx] == '"') { + ++(*idx); + return ret; + } else { + ret.str += json[*idx]; + ++(*idx); + } + } + std::cout << "Error! Unable to parse string" << std::endl; + return JsonVal(); + } + // parse a number JSON object + JsonVal parse_num(std::string json, unsigned int* idx) { + JsonVal ret(NUM); + while (*idx < json.size()) { + if (json[*idx] >= '0' && json[*idx] <= '9') { + ret.str += json[*idx]; + ++(*idx); + } else { + break; + } + } + ret.num = std::stoi(ret.str); + return ret; + } + // parse a list of JSON objects + JsonVal parse_list(std::string json, unsigned int* idx) { + JsonVal ret(LIST); + while (*idx < json.size()) { + if (json[*idx] == ']') { + ++(*idx); + return ret; + } else { + JsonVal item = parse(json, idx); + if (item.type != ERR) + ret.list.push_back(item); + } + } + std::cout << "Error! Unable to parse list" << std::endl; + return JsonVal(); + } + // parse a map of JSON objects + JsonVal parse_map(std::string json, unsigned int* idx) { + JsonVal ret(MAP), key; + while (*idx < json.size()) { + if (json[*idx] == '}') { + ++(*idx); + return ret; + } else { + JsonVal item = parse(json, idx); + if (key.type == ERR) { + key = item; + } else { + ret.map[key] = item; + key.type = ERR; + } + } + } + std::cout << "Error! Unable to parse map" << std::endl; + return JsonVal(); + } + // generic parse function + JsonVal parse(std::string json, unsigned int *idx) { + JsonVal ret; + while (*idx < json.size()) { + if (json[*idx] == '"') { + ++(*idx); + ret = parse_string(json, idx); + } else if (json[*idx] >= '0' && json[*idx] <= '9') { + ret = parse_num(json, idx); + } else if (json[*idx] == '[') { + ++(*idx); + ret = parse_list(json, idx); + } else if (json[*idx] == '{') { + ++(*idx); + ret = parse_map(json, idx); + } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;} + if (ret.type != ERR) return ret; + ++(*idx); + } + return ret; + } +}; + +/*! + * \brief An abstract class for library author creating stateful op + * custom library should override Forward and destructor, and has an + * option to implement Backward + */ +class CustomStatefulOp { + public: + virtual MXReturnValue Forward(std::vector inputs, + std::vector outputs, + OpResource op_res) = 0; + virtual MXReturnValue Backward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::cout << "Error! Operator does not support backward" << std::endl; + return MX_FAIL; + } +}; + +/*! \brief StatefulOp wrapper class to pass to backend OpState */ +class CustomStatefulOpWrapper { + public: + explicit CustomStatefulOpWrapper(CustomStatefulOp* inst) : instance(inst) {} + CustomStatefulOp* get_instance() { return instance; } + private: + CustomStatefulOp* instance; +}; + +/*! \brief Custom Operator function templates */ +typedef MXReturnValue (*fcomp_t)(std::map, + std::vector, std::vector, + OpResource res); +typedef MXReturnValue (*parseAttrs_t)(std::map, + int*, int*); +typedef MXReturnValue (*inferType_t)(std::map, + std::vector&, std::vector&); +typedef MXReturnValue (*inferShape_t)(std::map, + std::vector >&, + std::vector >&); +typedef MXReturnValue (*mutateInputs_t)(std::map, + std::vector&); +typedef MXReturnValue (*createOpState_t)(std::map, + CustomStatefulOp**); + +/*! + * \brief Class to hold custom operator registration + */ +class CustomOp { + public: + explicit CustomOp(const char* op_name) : name(op_name), + forward(NULL), backward(NULL), parse_attrs(NULL), infer_type(NULL), + infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL) {} + ~CustomOp() {} + CustomOp& setForward(fcomp_t fcomp) { + forward = fcomp; + return *this; + } + CustomOp& setBackward(fcomp_t fcomp) { + backward = fcomp; + return *this; + } + CustomOp& setParseAttrs(parseAttrs_t func) { + parse_attrs = func; + return *this; + } + CustomOp& setInferType(inferType_t func) { + infer_type = func; + return *this; + } + CustomOp& setInferShape(inferShape_t func) { + infer_shape = func; + return *this; + } + CustomOp& setMutateInputs(mutateInputs_t func) { + mutate_inputs = func; + return *this; + } + CustomOp& setCreateOpState(createOpState_t func) { + create_opstate = func; + return *this; + } + + /*! \brief operator name */ + const char* name; + /*! \brief operator functions */ + fcomp_t forward; + fcomp_t backward; + parseAttrs_t parse_attrs; + inferType_t infer_type; + inferShape_t infer_shape; + mutateInputs_t mutate_inputs; + createOpState_t create_opstate; +}; + +/*! + * \brief Registry class to registers things (ops, properties) + * Singleton class + */ +template +class Registry { + public: + /*! + * \brief get singleton pointer to class + * \returns pointer to class + */ + static Registry* get() { + static Registry inst; + return &inst; + } + /*! + * \brief add a new entry + * \returns new object associated with registered name + */ + T& add(const char* name) { + T *entry = new T(name); + entries.push_back(entry); + return *entry; + } + int size() { + return entries.size(); + } + T& get(int idx) { + return *(entries[idx]); + } + + private: + /*! \brief constructor */ + Registry() {} + /*! \brief destructor */ + ~Registry() {} + /*! \brief map of entries in registry */ + std::vector entries; +}; + +/*! + * \brief Macros to help with string concat + * Annoyingly, the concat_ and concat macros are necessary to + * be able to use __COUNTER__ in an identifier name + */ +#define MX_STR_CONCAT_(__a, __b) __a ## __b +#define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b) + +/*! \brief convert a token to a string */ +#define MX_STRINGIFY(x) #x +#define MX_TOSTRING(x) MX_STRINGIFY(x) + +/*! \brief declare a variable with custom name */ +#define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _ +#define MX_REGISTER_DEF_(Name) CustomOp MX_REGISTER_NAME_(Name) + +/*! \brief assign a var to a value */ +#define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \ + Registry::get()->add(MX_TOSTRING(Name)) + +/* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */ + +/*! + * \brief Following are the C type APIs implemented in the external library * Each API has a #define string that is used to lookup the function in the library * Followed by the function declaration */ +#define MXLIB_OPREGSIZE_STR "_opRegSize" +typedef int (*opRegSize_t)(void); + +#define MXLIB_OPREGGET_STR "_opRegGet" +typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*, + parseAttrs_t*, inferType_t*, + inferShape_t*, mutateInputs_t*, + createOpState_t*); + +#define MXLIB_OPCALLFREE_STR "_opCallFree" +typedef int (*opCallFree_t)(void*); + +#define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs" +typedef int (*opCallParseAttrs_t)(parseAttrs_t, const char* const*, const char* const*, int, + int*, int*); + +#define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape" +typedef int (*opCallInferShape_t)(inferShape_t, const char* const*, const char* const*, int, + unsigned int**, int*, int, + unsigned int***, int**, int); + +#define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType" +typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* const*, int, + int*, int, int*, int); + +#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" +typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int, + const int64_t**, int*, void**, int*, int, + const int64_t**, int*, void**, int*, int, + xpu_malloc_t, void*); + +#define MXLIB_OPCALLBKWD_STR "_opCallBackward" +typedef int (*opCallBkwd_t)(fcomp_t, const char* const*, const char* const*, int, + const int64_t**, int*, void**, int*, int, + const int64_t**, int*, void**, int*, int, + xpu_malloc_t, void*); + +#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" +typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int, + int**, int*); + +#define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState" +typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const char* const*, int, + void**); + +#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" +typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, int, + const int64_t**, int*, void**, int*, int, + xpu_malloc_t, void*); + #define MXLIB_INITIALIZE_STR "initialize" typedef int (*initialize_t)(int); +#define MXLIB_OPVERSION_STR "_opVersion" +typedef int (*opVersion_t)(); + extern "C" { - /*! - * \brief Checks if the MXNet version is supported by the library. - * If supported, initializes the library. - * \param version MXNet version number passed to library and defined as: - * MXNET_VERSION = (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH) - * \return Non-zero value on error i.e. library incompatible with passed MXNet version - */ + /*! \brief returns MXNet library version */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opVersion() { + return MX_LIBRARY_VERSION; + } + + /*! \brief returns number of ops registered in this library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opRegSize() { + return Registry::get()->size(); + } + + /*! \brief returns operator registration at specified index */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) void __cdecl +#else + void +#endif + _opRegGet(int idx, const char** name, fcomp_t* fcomp, fcomp_t* fgrad, + parseAttrs_t* parse, inferType_t* type, + inferShape_t* shape, mutateInputs_t* mutate, + createOpState_t* create_op) { + CustomOp op = Registry::get()->get(idx); + *name = op.name; + *fcomp = op.forward; + *fgrad = op.backward; + *parse = op.parse_attrs; + *type = op.infer_type; + *shape = op.infer_shape; + *mutate = op.mutate_inputs; + *create_op = op.create_opstate; + } + + /*! \brief calls free from the external library for library allocated arrays */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) void __cdecl +#else + void +#endif + _opCallFree(void* ptr) { + free(ptr); + } + + /*! \brief returns status of calling parse attributes function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys, + const char* const* vals, int num, + int* num_in, int* num_out) { + // create map of attributes from list + std::map attrs; + for (int i = 0; i < num; i++) { + attrs[std::string(keys[i])] = std::string(vals[i]); + } + + return parseAttrs(attrs, num_in, num_out); + } + + /*! \brief returns status of calling inferShape function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallInferShape(inferShape_t inferShape, const char* const* keys, + const char* const* vals, int num, + unsigned int** inshapes, int* indims, int num_in, + unsigned int*** outshapes, int** outdims, int num_out) { + // create map of attributes from list + std::map attrs; + for (int i = 0; i < num; i++) { + attrs[std::string(keys[i])] = std::string(vals[i]); + } + + // create a vector of shapes for inputs + std::vector > in_shapes(num_in); + for (int i = 0; i < num_in; i++) { + for (int j = 0; j < indims[i]; j++) { + in_shapes[i].push_back(inshapes[i][j]); + } + } + + // create a vector of shapes for outputs + std::vector > out_shapes(num_out); + + int retval = inferShape(attrs, in_shapes, out_shapes); + if (!retval) + return retval; + + // allocate space for output dims, shape + *outdims = static_cast(malloc (num_out * sizeof(int))); + *outshapes = static_cast(malloc (num_out * sizeof(unsigned*))); + + // copy output shapes + for (int i = 0; i < num_out; i++) { + (*outdims)[i] = out_shapes[i].size(); + (*outshapes)[i] = static_cast(malloc ((*outdims)[i] * sizeof(unsigned))); + for (int j = 0; j < indims[i]; j++) { + (*outshapes)[i][j] = out_shapes[i][j]; + } + } + + return retval; + } + + /*! \brief returns status of calling inferType function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallInferType(inferType_t inferType, const char* const* keys, + const char* const* vals, int num, + int* intypes, int num_in, int* outtypes, int num_out) { + // create map of attributes from list + std::map attrs; + for (int i = 0; i < num; i++) { + attrs[std::string(keys[i])] = std::string(vals[i]); + } + + // create a vector of types for inputs + std::vector in_types(num_in); + for (int i = 0; i < num_in; i++) { + in_types[i] = intypes[i]; + } + + // create a vector of types for outputs + std::vector out_types(num_out, -1); + + int retval = inferType(attrs, in_types, out_types); + if (!retval) + return retval; + + // copy output types + for (int i = 0; i < num_out; i++) { + outtypes[i] = out_types[i]; + } + + return retval; + } + + /*! \brief returns status of calling Forward/Backward function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallFCompute(fcomp_t fcomp, const char* const* keys, + const char* const* vals, int num, + const int64_t** inshapes, int* indims, + void** indata, int* intypes, int num_in, + const int64_t** outshapes, int* outdims, + void** outdata, int* outtypes, int num_out, + xpu_malloc_t cpu_malloc, void* cpu_alloc) { + // create map of attributes from list + std::map attrs; + for (int i = 0; i < num; i++) { + attrs[std::string(keys[i])] = std::string(vals[i]); + } + + // create a vector of tensors for inputs + std::vector inputs(num_in); + for (int i = 0; i < num_in; i++) { + inputs[i].data_ptr = indata[i]; + inputs[i].dtype = (MXDType)intypes[i]; + for (int j = 0; j < indims[i]; j++) { + inputs[i].shape.push_back(inshapes[i][j]); + } + inputs[i].setDLTensor(); + } + + // create a vector of tensors for outputs + std::vector outputs(num_out); + for (int i = 0; i < num_out; i++) { + outputs[i].data_ptr = outdata[i]; + outputs[i].dtype = (MXDType) outtypes[i]; + for (int j = 0; j < outdims[i]; j++) { + outputs[i].shape.push_back(outshapes[i][j]); + } + outputs[i].setDLTensor(); + } + + OpResource res(cpu_malloc, cpu_alloc); + + return fcomp(attrs, inputs, outputs, res); + } + + /*! \brief returns status of calling mutateInputs function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallMutateInputs(mutateInputs_t mutate, const char* const* keys, + const char* const* vals, int num, + int** mutate_indices, int* indices_size) { + // create map of attributes from list + std::map attrs; + for (int i = 0; i < num; i++) { + attrs[std::string(keys[i])] = std::string(vals[i]); + } + + // create a vector of mutate input indices + std::vector mut_ind; + + int retval = mutate(attrs, mut_ind); + if (!retval) + return retval; + + // output the input indices + *indices_size = mut_ind.size(); + *mutate_indices = static_cast(malloc (*indices_size * sizeof(int))); + for (int i = 0; i < *indices_size; i++) { + (*mutate_indices)[i] = mut_ind[i]; + } + + return retval; + } + + /*! \brief returns status of calling createStatefulOp function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallCreateOpState(createOpState_t create_op, const char* const* keys, + const char* const* vals, int num, + void** state_op) { + // create map of attributes from list + std::map attrs; + for (int i = 0; i < num; i++) { + attrs[std::string(keys[i])] = std::string(vals[i]); + } + + // void pointer to hold custom state op instance created in custom library + CustomStatefulOp** op_ptr = reinterpret_cast(state_op); + return create_op(attrs, op_ptr); + } + + /*! \brief returns status of calling Stateful Forward/Backward for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _opCallFStatefulCompute(bool is_forward, void* state_op, + const int64_t** inshapes, int* indims, + void** indata, int* intypes, int num_in, + const int64_t** outshapes, int* outdims, + void** outdata, int* outtypes, int num_out, + xpu_malloc_t cpu_malloc, void* cpu_alloc) { + // create a vector of tensors for inputs + std::vector inputs(num_in); + for (int i = 0; i < num_in; i++) { + inputs[i].data_ptr = indata[i]; + inputs[i].dtype = (MXDType)intypes[i]; + for (int j = 0; j < indims[i]; j++) { + inputs[i].shape.push_back(inshapes[i][j]); + } + inputs[i].setDLTensor(); + } + + // create a vector of tensors for outputs + std::vector outputs(num_out); + for (int i = 0; i < num_out; i++) { + outputs[i].data_ptr = outdata[i]; + outputs[i].dtype = (MXDType) outtypes[i]; + for (int j = 0; j < outdims[i]; j++) { + outputs[i].shape.push_back(outshapes[i][j]); + } + outputs[i].setDLTensor(); + } + OpResource res(cpu_malloc, cpu_alloc); + CustomStatefulOp* op_ptr = reinterpret_cast(state_op); + if (is_forward) { + return op_ptr->Forward(inputs, outputs, res); + } + return op_ptr->Backward(inputs, outputs, res); + } + + /*! + * \brief Checks if the MXNet version is supported by the library. + * If supported, initializes the library. + * \param version MXNet version number passed to library and defined as: + * MXNET_VERSION = (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH) + * \return Non-zero value on error i.e. library incompatible with passed MXNet version + */ #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl initialize(int); + __declspec(dllexport) MXReturnValue __cdecl #else - int initialize(int); + MXReturnValue #endif + initialize(int version); } #endif // MXNET_LIB_API_H_ diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 088ffc1ea9d0..3a35d333e479 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -27,7 +27,6 @@ from .util import is_np_shape, set_np_shape, np_shape, use_np_shape from .util import is_np_array, np_array, use_np_array, use_np from . import base -from . import library from . import contrib from . import ndarray from . import ndarray as nd @@ -87,6 +86,8 @@ from . import gluon +# Dynamic library module should be done after ndarray and symbol are initialized +from . import library from . import tvmop __version__ = base.__version__ diff --git a/python/mxnet/library.py b/python/mxnet/library.py index 9ebf2c2bc580..8ea0bc2ae0a5 100644 --- a/python/mxnet/library.py +++ b/python/mxnet/library.py @@ -19,8 +19,11 @@ """Library management API of mxnet.""" from __future__ import absolute_import import ctypes +import sys import os -from .base import _LIB, check_call, MXNetError +from .base import _LIB, check_call, MXNetError, _init_op_module +from .ndarray.register import _make_ndarray_function +from .symbol.register import _make_symbol_function def load(path): """Loads library dynamically. @@ -47,3 +50,21 @@ def load(path): byt_obj = path.encode('utf-8') chararr = ctypes.c_char_p(byt_obj) check_call(_LIB.MXLoadLib(chararr)) + + #regenerate operators + _init_op_module('mxnet', 'ndarray', _make_ndarray_function) + _init_op_module('mxnet', 'symbol', _make_symbol_function) + + #re-register mx.nd.op into mx.nd + mx_nd = sys.modules["mxnet.ndarray"] + mx_nd_op = sys.modules["mxnet.ndarray.op"] + for op in dir(mx_nd_op): + func = getattr(mx_nd_op, op) + setattr(mx_nd, op, func) + + #re-register mx.sym.op into mx.sym + mx_sym = sys.modules["mxnet.symbol"] + mx_sym_op = sys.modules["mxnet.symbol.op"] + for op in dir(mx_sym_op): + func = getattr(mx_sym_op, op) + setattr(mx_sym, op, func) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ba33084a026d..24374cf19cdc 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -49,9 +50,11 @@ #include "../initialize.h" #include "./c_api_common.h" #include "../operator/custom/custom-inl.h" +#include "../operator/operator_common.h" #include "../operator/tensor/matrix_op-inl.h" #include "../operator/tvmop/op_module.h" #include "../common/utils.h" +#include "nnvm/pass_functions.h" using namespace mxnet; @@ -92,16 +95,593 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, // NOTE: return value is added in API_END -// Loads library and initializes it +/*! + * \brief Loads dynamic library and initializes it + * \param path library path + */ int MXLoadLib(const char *path) { API_BEGIN(); void *lib = LibraryInitializer::Get()->lib_load(path); if (!lib) LOG(FATAL) << "Unable to load library"; + // check that library and MXNet use same version of library API + opVersion_t opVersion = get_func(lib, const_cast(MXLIB_OPVERSION_STR)); + int libVersion = opVersion(); + if (MX_LIBRARY_VERSION != libVersion) + LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version (" + << MX_LIBRARY_VERSION << ")"; + + // initialize library by passing MXNet version initialize_t initialize = get_func(lib, const_cast(MXLIB_INITIALIZE_STR)); if (!initialize(static_cast(MXNET_VERSION))) LOG(FATAL) << "Library failed to initialize"; + + // get C type interface functions + opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); + + opCallParseAttrs_t callParseAttrs = + get_func(lib, const_cast(MXLIB_OPCALLPARSEATTRS_STR)); + + opCallInferShape_t callInferShape = + get_func(lib, const_cast(MXLIB_OPCALLINFERSHAPE_STR)); + + opCallInferType_t callInferType = + get_func(lib, const_cast(MXLIB_OPCALLINFERTYPE_STR)); + + opCallFComp_t callFComp = + get_func(lib, const_cast(MXLIB_OPCALLFCOMP_STR)); + + opCallMutateInputs_t callMutateInputs = + get_func(lib, const_cast(MXLIB_OPCALLMUTATEINPUTS_STR)); + + opCallCreateOpState_t callCreateOpState = + get_func(lib, const_cast(MXLIB_OPCALLCREATEOPSTATE_STR)); + + opCallFStatefulComp_t callFStatefulComp = + get_func(lib, const_cast(MXLIB_OPCALLFSTATEFULCOMP_STR)); + + // get number of operators registered in the library + opRegSize_t opRegSize = get_func(lib, const_cast(MXLIB_OPREGSIZE_STR)); + int numOps = opRegSize(); + LOG(INFO) << "Found " << numOps << " operators in library"; + + /* + * Get all custom operators implementation from custom library + * loop and register each operator in the library to NNVM + */ + opRegGet_t opRegGet = get_func(lib, const_cast(MXLIB_OPREGGET_STR)); + for (int i = 0; i < numOps; i++) { + const char* name; + // function pointers holding implementation from custom library + fcomp_t fcomp_fp = nullptr; + parseAttrs_t parse_fp = nullptr; + inferType_t type_fp = nullptr; + inferShape_t shape_fp = nullptr; + // optional attributes + fcomp_t fgrad_fp = nullptr; + mutateInputs_t mutate_fp = nullptr; + createOpState_t create_opstate_fp = nullptr; + + // get custom operator implemenation from the dynamic library + opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp, + &mutate_fp, &create_opstate_fp); + + // validate custom operator functions from the dynamic library + CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name + << "' custom op, Forward or CreateOpState function was not set."; + CHECK(parse_fp != nullptr) << "Error loading '" << name + << "' custom op, ParseAttrs function was not set."; + CHECK(type_fp != nullptr) << "Error loading '" << name + << "' custom op, InferType function was not set."; + CHECK(shape_fp != nullptr) << "Error loading '" << name + << "' custom op, InferShape function was not set."; + + LOG(INFO) << "\tOp[" << i << "] " << name; + std::string name_str(name); + + /* + * Below are a series of lambda functions that will be registered in the NNVM op registration + * Each one has the standard MXNet signature and converts to types supported by externally + * registered operators. + */ + + // lambda function to call parse attributes + auto attr_parser = [=](const NodeAttrs* attrs) { + // convert attributes to vector of char + std::vector attr_keys, attr_vals; + for (auto kv : attrs->dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + // convert subgraph symbol from node attributes to char* + std::string subgraph_json; + if (!attrs->subgraphs.empty()) { + nnvm::Graph g; + g.outputs = attrs->subgraphs[0].get()->outputs; + subgraph_json = nnvm::pass::SaveJSON(g); + attr_keys.push_back(SUBGRAPH_SYM_JSON); + attr_vals.push_back(subgraph_json.c_str()); + } + + int num_in = -1; + int num_out = -1; + CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out)) + << "Error calling ParseAttrs for custom operator '" << name_str << "'"; + + // return type void + }; + + // lambda function to call parse attributes and return the number of inputs + auto num_inputs = [=](const NodeAttrs& attrs) { + // convert attributes to vector of char + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + int num_in = -1; + int num_out = -1; + CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out)) + << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str << "'"; + + return num_in; + }; + + // lambda function to call parse attributes and return the number of outputs + auto num_outputs = [=](const NodeAttrs& attrs) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + int num_in = -1; + int num_out = -1; + CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out)) + << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'"; + + return num_out; + }; + + // lambda function to call parse attributes and return the number of inputs and outputs + // for backward computation + auto num_inouts = [=](const NodeAttrs& attrs) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + int num_in = -1; + int num_out = -1; + CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out)) + << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'"; + + return num_in + num_out; + }; + + // lambda function to call infer shape + auto infer_shape = [=] (const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + std::vector inshapes(in_shape->size()); + std::vector indims(in_shape->size()); + + // determine amount of memory needed to store all the input shapes + size_t buff_size = 0; + for (const auto& i : *in_shape) buff_size += i.ndim(); + + // copy input shapes from ShapeVector to raw memory layout + std::vector inbuff(buff_size); + uint32_t *ptr = inbuff.data(); + for (size_t i = 0; i < in_shape->size(); ++i) { + inshapes[i] = ptr; + indims[i] = (*in_shape)[i].ndim(); + for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) { + *ptr = static_cast((*in_shape)[i][j]); + } + } + + // output shapes will be allocated by infer shape function + uint32_t** outshapes = nullptr; + int* outdims = nullptr; + + CHECK(callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + inshapes.data(), indims.data(), in_shape->size(), + &outshapes, &outdims, out_shape->size())) + << "Error calling InferShape for custom operator '" << name_str << "'"; + + std::vector out_shapes(out_shape->size()); + // determine amount of memory needed to store all the output shapes + buff_size = 0; + for (unsigned i = 0; i < out_shape->size(); i++) { + buff_size += outdims[i]; + } + + // copy output shapes from custom op memory to MXNet memory + std::vector outbuff(buff_size); + ptr = outbuff.data(); + for (unsigned i = 0; i < out_shape->size(); ++i) { + out_shapes[i] = ptr; + for (int j = 0; j < outdims[i]; ++j, ++ptr) { + *ptr = static_cast(outshapes[i][j]); + } + } + + // assign output shapes to ShapeVector + for (unsigned i = 0; i < out_shape->size(); ++i) { + SHAPE_ASSIGN_CHECK(*out_shape, i, + mxnet::TShape(out_shapes[i], out_shapes[i]+outdims[i])); + } + + // free memory used by custom op to allocate shapes/dims + callFree(outdims); + for (unsigned i = 0; i < out_shape->size(); i++) { + callFree(outshapes[i]); + } + callFree(outshapes); + + return true; + }; + + // lambda function to call infer type + auto infer_type = [=] (const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + // copy input types from in_type + std::vector intypes(*in_type); + + // output types will be populated by inferType function + std::vector outtypes(out_type->size()); + + CHECK(callInferType(type_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + intypes.data(), in_type->size(), + outtypes.data(), out_type->size())) + << "Error calling InferType for custom operator '" << name_str << "'"; + + // copy and assign output types from custom op to MXNet memory + for (size_t i = 0; i < out_type->size(); i++) { + TYPE_ASSIGN_CHECK(*out_type, i, outtypes[i]); + } + + return true; + }; + + // lambda function to convert from external fcompute to internal MXNet types + auto fcomp_lambda = [=](fcomp_t fcomp_fp, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + std::vector in_data, out_data; + std::vector in_shapes, out_shapes; + std::vector in_dims, out_dims; + std::vector in_types, out_types; + + // convert input tensors to constituent parts + for (size_t i = 0; i < inputs.size(); i++) { + in_data.push_back(inputs[i].data().dptr_); + in_shapes.push_back(inputs[i].shape().data()); + in_dims.push_back(inputs[i].shape().ndim()); + in_types.push_back(inputs[i].dtype()); + } + + // convert output tensors to constituent parts + for (size_t i = 0; i < outputs.size(); i++) { + out_data.push_back(outputs[i].data().dptr_); + out_shapes.push_back(outputs[i].shape().data()); + out_dims.push_back(outputs[i].shape().ndim()); + out_types.push_back(outputs[i].dtype()); + } + + // get memory resource + const Resource &resource = ctx.requested[0]; + mshadow::Stream *cpu_stream = ctx.get_stream(); + + // create lambda that captures stream & resource objects + // this temp workspace holds memory allocated by custom library via OpResource + auto cpu_alloc = [&](int size) { + mshadow::Tensor workspace = + resource.get_space_typed(mshadow::Shape1(size), cpu_stream); + return workspace.dptr_; + }; + + // create lambda without captures so that we can cast it to function pointer + // this needs to be a lambda function so that we can do the decltype cast + typedef decltype(cpu_alloc) alloc_type; + auto cpu_malloc = [](void* _cpu_alloc, int size) { + // cast the void* argument to the type for the cpu_alloc lambda function + alloc_type* cpualloc = static_cast(_cpu_alloc); + // call cpu_alloc to actually allocate memory and get the pointer + void* ptr = (*cpualloc)(size); + return ptr; + }; + + // call fcompute function + CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + in_shapes.data(), in_dims.data(), in_data.data(), + in_types.data(), in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), + out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc)) + << "Error calling FCompute for custom operator '" << name_str << "'"; + + // return type void + }; + + auto forward_lambda = [=](const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + return fcomp_lambda(fcomp_fp, attrs, ctx, inputs, req, outputs); + }; + + auto backward_lambda = [=](const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + return fcomp_lambda(fgrad_fp, attrs, ctx, inputs, req, outputs); + }; + + // lambda function to convert from external mutate_inputs to internal MXNet types + auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + // C type placeholder for mutate input indices vector + int* mutate_indices = nullptr; + int indices_size = 0; + + // call mutate inputs function + CHECK(callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &mutate_indices, &indices_size)) + << "Error calling MutateInputs for custom operator '" << name_str << "'"; + + std::vector mutate_indices_list(indices_size); + for (int i=0; i < indices_size; i++) { + mutate_indices_list[i] = static_cast(mutate_indices[i]); + } + + return mutate_indices_list; + }; + + // lambda function to set storage types + auto infer_storage_type = [=](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + // TODO(ziyimu): remove this dense enforce check after supporting sparse tensor + CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage)) + << "Error input tensors are not dense for custom operator '" << name_str << "'"; + // set outputs as dense + return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFComputeEx); + }; + + // FGradient register lambda + auto grad_reg = [=](const nnvm::NodePtr& n, const std::vector& ograds) { + // copy gradients first + std::vector heads(ograds.begin(), ograds.end()); + // copy inputs second + for (auto& h : n->inputs) { + heads.push_back(h); + } + // copy outputs last + uint32_t n_out = n->num_outputs(); + for (uint32_t i = 0; i < n_out; ++i) { + heads.emplace_back(n, i, 0); + } + std::string grad_name = "_backward_" + name_str; + return mxnet::op::MakeGradNode(grad_name.c_str(), n, heads, n->attrs.dict); + }; + + auto resc_req = [=](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }; + + // library author should implement and return a 'state' which points to an instance + // in lambda we create OpStatePtr using the returned 'state' + auto create_opstate = [=] (const NodeAttrs& attrs, + Context ctx, + const std::vector& in_shapes, + const std::vector& in_types) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + // convert subgraph symbol from node attributes to char* + std::string subgraph_json; + if (!attrs.subgraphs.empty()) { + nnvm::Graph g; + g.outputs = attrs.subgraphs[0].get()->outputs; + subgraph_json = nnvm::pass::SaveJSON(g); + attr_keys.push_back(SUBGRAPH_SYM_JSON); + attr_vals.push_back(subgraph_json.c_str()); + } + + // create a pointer to hold custom op state object + void* state_op_inst = nullptr; + CHECK(callCreateOpState(create_opstate_fp, attr_keys.data(), attr_vals.data(), + attr_keys.size(), &state_op_inst)) + << "Error calling CreateOpState for custom operator '" << name_str << "'"; + + CHECK(state_op_inst != nullptr) + << "Error custom library failed to create stateful operator '" << name_str << "'"; + + CustomStatefulOp* state_op = reinterpret_cast(state_op_inst); + return OpStatePtr::Create(state_op); + }; + + // stateful forward and backward + auto fstateful_lambda = [=](bool is_forward, + const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + std::vector in_data, out_data; + std::vector in_shapes, out_shapes; + std::vector in_dims, out_dims; + std::vector in_types, out_types; + + // convert input tensors to constituent parts + for (size_t i = 0; i < inputs.size(); i++) { + in_data.push_back(inputs[i].data().dptr_); + in_shapes.push_back(inputs[i].shape().data()); + in_dims.push_back(inputs[i].shape().ndim()); + in_types.push_back(inputs[i].dtype()); + } + + // convert output tensors to constituent parts + for (size_t i = 0; i < outputs.size(); i++) { + out_data.push_back(outputs[i].data().dptr_); + out_shapes.push_back(outputs[i].shape().data()); + out_dims.push_back(outputs[i].shape().ndim()); + out_types.push_back(outputs[i].dtype()); + } + + // get memory resource + const Resource &resource = ctx.requested[0]; + mshadow::Stream *cpu_stream = ctx.get_stream(); + + // create lambda that captures stream & resource objects + // this temp workspace holds memory allocated by custom library via OpResource + auto cpu_alloc = [&](int size) { + mshadow::Tensor data = + resource.get_space_typed(mshadow::Shape1(size), cpu_stream); + return data.dptr_; + }; + + // create lambda without captures so that we can cast it to function pointer + // this needs to be a lambda function so that we can do the decltype cast + typedef decltype(cpu_alloc) alloc_type; + auto cpu_malloc = [](void* _cpu_alloc, int size) { + // cast the void* argument to the type for the cpu_alloc lambda function + alloc_type* cpualloc = static_cast(_cpu_alloc); + // call cpu_alloc to actually allocate memory and get the pointer + void* ptr = (*cpualloc)(size); + return ptr; + }; + + // retrieve op state object created from CreateOpState + CustomStatefulOpWrapper& op = state_ptr.get_state(); + CustomStatefulOp* state_op_inst = op.get_instance(); + CHECK(state_op_inst != nullptr) + << "Error MXNet cannot load custom stateful operator'" << name_str << "'"; + + // call fcompute function + CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(), in_dims.data(), + in_data.data(), in_types.data(), in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), + out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc)) + << "Error calling FStatefulCompute for custom operator '" << name_str << "'"; + }; + + auto fstateful_forward = [=](const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + fstateful_lambda(true, state_ptr, ctx, inputs, req, outputs); + }; + + auto fstateful_backward = [=](const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + fstateful_lambda(false, state_ptr, ctx, inputs, req, outputs); + }; + + // check if operator is already registered + const nnvm::Op *regOpPtr = dmlc::Registry::Get()->Find(name); + nnvm::Op ®Op = dmlc::Registry::Get()->__REGISTER_OR_GET__(name); + regOp.set_attr_parser(attr_parser); + regOp.set_num_inputs(num_inputs); + regOp.set_num_outputs(num_outputs); + int plevel = 10; + if (regOpPtr != nullptr) { + // overwrite registration of existing op with custom op + regOp.arguments.clear(); + // set attribute with higher plevel (11) to allow re-registering once + // TODO(samskalicky): enable constant overwriting of registertion multiple times + plevel++; + } + regOp.set_attr("FInferType", infer_type, plevel); + regOp.set_attr("FInferShape", infer_shape, plevel); + regOp.set_attr("FInferStorageType", infer_storage_type, plevel); + regOp.set_attr("FResourceRequest", resc_req, plevel); + // optionally add stateful forward + if (create_opstate_fp != nullptr) { + regOp.set_attr("FCreateOpState", create_opstate, plevel); + regOp.set_attr("FStatefulComputeEx", + fstateful_forward, plevel); + } else { + regOp.set_attr("FComputeEx", forward_lambda, plevel); + } + // optionally add fmutate inputs if user specified a function + if (mutate_fp != nullptr) + regOp.set_attr("FMutateInputs", mutate_inputs, plevel); + // optionally add fgradient if user specified a function + if (fgrad_fp != nullptr || create_opstate_fp != nullptr) { + regOp.set_attr("FGradient", grad_reg, plevel); + std::string grad_name = "_backward_" + name_str; + nnvm::Op &gradOp = dmlc::Registry::Get()->__REGISTER_OR_GET__(grad_name); + gradOp.set_attr("TIsBackward", true, plevel); + gradOp.set_attr_parser(attr_parser); + gradOp.set_num_inputs(num_inouts); + gradOp.set_num_outputs(num_inputs); + gradOp.set_attr("FInferStorageType", infer_storage_type, plevel); + gradOp.set_attr("FResourceRequest", resc_req, plevel); + if (create_opstate_fp != nullptr) { + gradOp.set_attr("TIsLayerOpBackward", true, plevel); + gradOp.set_attr("FStatefulComputeEx", + fstateful_backward, plevel); + } else { + gradOp.set_attr("FComputeEx", backward_lambda, plevel); + } + } + regOp.add_argument("data", "NDArray[]", "Source inputs"); + } API_END(); } diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 075261300eff..efa55d2c1cde 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -47,7 +47,7 @@ from test_gluon_gpu import _test_bulking from test_contrib_operator import test_multibox_target_op from test_tvm_op import * -from test_library_loading import * +from test_extensions import * from test_contrib_optimizer import test_adamw set_default_context(mx.gpu(0)) diff --git a/tests/python/unittest/test_library_loading.py b/tests/python/unittest/test_extensions.py similarity index 50% rename from tests/python/unittest/test_library_loading.py rename to tests/python/unittest/test_extensions.py index 29b99dacdbe1..cc7858dce0fd 100644 --- a/tests/python/unittest/test_library_loading.py +++ b/tests/python/unittest/test_extensions.py @@ -21,15 +21,16 @@ import platform import unittest import mxnet as mx +import numpy as np from mxnet.base import MXNetError -from mxnet.test_utils import download, is_cd_run +from mxnet.test_utils import download, is_cd_run, assert_almost_equal def check_platform(): return platform.machine() not in ['x86_64', 'AMD64'] @unittest.skipIf(check_platform(), "not all machine types supported") @unittest.skipIf(is_cd_run(), "continuous delivery run - ignoring test") -def test_library_loading(): +def test_custom_op(): if (os.name=='posix'): lib = 'libsample_lib.so' if os.path.exists(lib): @@ -47,3 +48,39 @@ def test_library_loading(): fname = os.path.abspath(fname) mx.library.load(fname) + + # test simple 2D gemm custom op loaded from sample library + s = mx.sym.Variable('s') + t = mx.sym.Variable('t') + c = mx.sym.my_gemm(s,t) + d = mx.sym.state_gemm(s,t) + base = mx.sym.linalg.gemm2(s,t) # baseline + + dim_n, dim_k, dim_m = tuple(np.random.randint(1, 5, size=3)) + + mat1 = mx.nd.random.uniform(-10, 10, shape=(dim_n, dim_k), ctx=mx.cpu()) + mat2 = mx.nd.random.uniform(-10, 10, shape=(dim_k, dim_m), ctx=mx.cpu()) + + in_grad1 = [mx.nd.empty((dim_n,dim_k),ctx=mx.cpu()),mx.nd.empty((dim_k,dim_m),ctx=mx.cpu())] + in_grad2 = [mx.nd.empty((dim_n,dim_k),ctx=mx.cpu()),mx.nd.empty((dim_k,dim_m),ctx=mx.cpu())] + in_grad_base = [mx.nd.empty((dim_n,dim_k),ctx=mx.cpu()),mx.nd.empty((dim_k,dim_m),ctx=mx.cpu())] + + exe1 = c.bind(ctx=mx.cpu(),args={'s':mat1,'t':mat2},args_grad=in_grad1) + exe2 = d.bind(ctx=mx.cpu(),args={'s':mat1,'t':mat2},args_grad=in_grad2) + exe_base = base.bind(ctx=mx.cpu(),args={'s':mat1,'t':mat2},args_grad=in_grad_base) + + out1 = exe1.forward() + out2 = exe2.forward() + out2 = exe2.forward() # stateful + out_base = exe_base.forward() + + assert_almost_equal(out_base[0].asnumpy(), out1[0].asnumpy(), rtol=1e-3, atol=1e-3) + assert_almost_equal(out_base[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3) + + out_grad = mx.nd.ones((dim_n, dim_m), ctx=mx.cpu()) + exe1.backward([out_grad]) + exe2.backward([out_grad]) + exe_base.backward([out_grad]) + + assert_almost_equal(in_grad_base[0].asnumpy(), in_grad1[0].asnumpy(), rtol=1e-3, atol=1e-3) + assert_almost_equal(in_grad_base[0].asnumpy(), in_grad2[0].asnumpy(), rtol=1e-3, atol=1e-3)