From 0a4e6e0837e36dc9e418c0270097e0444d17186d Mon Sep 17 00:00:00 2001 From: samskalicky Date: Mon, 30 Dec 2019 16:24:02 +0000 Subject: [PATCH 1/4] initial commit --- include/mxnet/lib_api.h | 13 +++++++--- src/c_api/c_api.cc | 53 +++++++++++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 290a63518373..9808d76d9a7c 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -541,7 +541,8 @@ class CustomOp { public: explicit CustomOp(const char* op_name) : name(op_name), forward(NULL), backward(NULL), parse_attrs(NULL), infer_type(NULL), - infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL) {} + infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL), + isSGop(false) {} ~CustomOp() {} CustomOp& setForward(fcomp_t fcomp) { forward = fcomp; @@ -571,6 +572,10 @@ class CustomOp { create_opstate = func; return *this; } + CustomOp& isSubgraphOp(bool state) { + isSGop = state; + return *this; + } /*! \brief operator name */ const char* name; @@ -582,6 +587,7 @@ class CustomOp { inferShape_t infer_shape; mutateInputs_t mutate_inputs; createOpState_t create_opstate; + bool isSGop; }; /*! @@ -658,7 +664,7 @@ typedef int (*opRegSize_t)(void); typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*, parseAttrs_t*, inferType_t*, inferShape_t*, mutateInputs_t*, - createOpState_t*); + createOpState_t*, bool*); #define MXLIB_OPCALLFREE_STR "_opCallFree" typedef int (*opCallFree_t)(void*); @@ -737,7 +743,7 @@ extern "C" { _opRegGet(int idx, const char** name, fcomp_t* fcomp, fcomp_t* fgrad, parseAttrs_t* parse, inferType_t* type, inferShape_t* shape, mutateInputs_t* mutate, - createOpState_t* create_op) { + createOpState_t* create_op, bool *isSGop) { CustomOp op = Registry::get()->get(idx); *name = op.name; *fcomp = op.forward; @@ -747,6 +753,7 @@ extern "C" { *shape = op.infer_shape; *mutate = op.mutate_inputs; *create_op = op.create_opstate; + *isSGop = op.isSGop; } /*! \brief calls free from the external library for library allocated arrays */ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f8db501d46f0..59e32b888403 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -51,6 +51,7 @@ #include "./c_api_common.h" #include "../operator/custom/custom-inl.h" #include "../operator/operator_common.h" +#include "../operator/subgraph/common.h" #include "../operator/tensor/matrix_op-inl.h" #include "../operator/tvmop/op_module.h" #include "../common/utils.h" @@ -162,21 +163,28 @@ int MXLoadLib(const char *path) { fcomp_t fgrad_fp = nullptr; mutateInputs_t mutate_fp = nullptr; createOpState_t create_opstate_fp = nullptr; - + bool isSubgraphOp = false; + // get custom operator implemenation from the dynamic library opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp, - &mutate_fp, &create_opstate_fp); + &mutate_fp, &create_opstate_fp, &isSubgraphOp); - // validate custom operator functions from the dynamic library - CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name + if(!isSubgraphOp) { + // validate custom operator functions from the dynamic library + CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name << "' custom op, Forward or CreateOpState function was not set."; - CHECK(parse_fp != nullptr) << "Error loading '" << name + CHECK(parse_fp != nullptr) << "Error loading '" << name << "' custom op, ParseAttrs function was not set."; - CHECK(type_fp != nullptr) << "Error loading '" << name + CHECK(type_fp != nullptr) << "Error loading '" << name << "' custom op, InferType function was not set."; - CHECK(shape_fp != nullptr) << "Error loading '" << name + CHECK(shape_fp != nullptr) << "Error loading '" << name << "' custom op, InferShape function was not set."; - + } else { + // validate custom operator functions from the dynamic library + CHECK(create_opstate_fp != nullptr) << "Error loading '" << name + << "' custom subgraph op, CreateOpState function was not set."; + } + LOG(INFO) << "\tOp[" << i << "] " << name; std::string name_str(name); @@ -646,10 +654,28 @@ int MXLoadLib(const char *path) { // TODO(samskalicky): enable constant overwriting of registertion multiple times plevel++; } - regOp.set_attr("FInferType", infer_type, plevel); - regOp.set_attr("FInferShape", infer_shape, plevel); - regOp.set_attr("FInferStorageType", infer_storage_type, plevel); - regOp.set_attr("FResourceRequest", resc_req, plevel); + if(!isSubgraphOp) { + regOp.set_attr("FInferType", infer_type, plevel); + regOp.set_attr("FInferShape", infer_shape, plevel); + regOp.set_attr("FInferStorageType", infer_storage_type, plevel); + regOp.set_attr("FResourceRequest", resc_req, plevel); + // optionally add fmutate inputs if user specified a function + if (mutate_fp != nullptr) + regOp.set_attr("FMutateInputs", mutate_inputs, plevel); + } else { + using namespace mxnet::op; + regOp.set_attr("FInferType", + DefaultSubgraphOpType, plevel); + regOp.set_attr("FInferShape", + DefaultSubgraphOpShape, plevel); + regOp.set_attr("FInferStorageType", + DefaultSubgraphOpStorageType, plevel); + regOp.set_attr("FResourceRequest", + DefaultSubgraphOpResourceRequest, plevel); + regOp.set_attr("FMutateInputs", + DefaultSubgraphOpMutableInputs, plevel); + } + // optionally add stateful forward if (create_opstate_fp != nullptr) { regOp.set_attr("FCreateOpState", create_opstate, plevel); @@ -658,9 +684,6 @@ int MXLoadLib(const char *path) { } else { regOp.set_attr("FComputeEx", forward_lambda, plevel); } - // optionally add fmutate inputs if user specified a function - if (mutate_fp != nullptr) - regOp.set_attr("FMutateInputs", mutate_inputs, plevel); // optionally add fgradient if user specified a function if (fgrad_fp != nullptr || create_opstate_fp != nullptr) { regOp.set_attr("FGradient", grad_reg, plevel); From cabb4c6a428054a70898a7f2f27f60ca91f90513 Mon Sep 17 00:00:00 2001 From: samskalicky Date: Mon, 30 Dec 2019 16:44:22 +0000 Subject: [PATCH 2/4] added flag on user library side in example --- .../extensions/lib_custom_op/subgraph_lib.cc | 17 +---------------- src/c_api/c_api.cc | 4 ++-- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/example/extensions/lib_custom_op/subgraph_lib.cc b/example/extensions/lib_custom_op/subgraph_lib.cc index 8e7e8833745a..c5d230894d3f 100644 --- a/example/extensions/lib_custom_op/subgraph_lib.cc +++ b/example/extensions/lib_custom_op/subgraph_lib.cc @@ -46,20 +46,6 @@ MXReturnValue parseAttrs(std::map attrs, return MX_SUCCESS; } -MXReturnValue inferType(std::map attrs, - std::vector &intypes, - std::vector &outtypes) { - outtypes[0] = intypes[0]; - return MX_SUCCESS; -} - -MXReturnValue inferShape(std::map attrs, - std::vector> &inshapes, - std::vector> &outshapes) { - outshapes[0] = inshapes[0]; - return MX_SUCCESS; -} - class MyStatefulOp : public CustomStatefulOp { public: explicit MyStatefulOp(std::string sym) : subgraph_sym(sym) {} @@ -98,8 +84,7 @@ MXReturnValue createOpState(std::map attrs, REGISTER_OP(_custom_subgraph_op) .setParseAttrs(parseAttrs) -.setInferType(inferType) -.setInferShape(inferShape) +.isSubgraphOp(true) .setCreateOpState(createOpState); MXReturnValue initialize(int version) { diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 59e32b888403..3bc9656e1ce5 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -169,12 +169,12 @@ int MXLoadLib(const char *path) { opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp, &mutate_fp, &create_opstate_fp, &isSubgraphOp); + CHECK(parse_fp != nullptr) << "Error loading '" << name + << "' custom op, ParseAttrs function was not set."; if(!isSubgraphOp) { // validate custom operator functions from the dynamic library CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name << "' custom op, Forward or CreateOpState function was not set."; - CHECK(parse_fp != nullptr) << "Error loading '" << name - << "' custom op, ParseAttrs function was not set."; CHECK(type_fp != nullptr) << "Error loading '" << name << "' custom op, InferType function was not set."; CHECK(shape_fp != nullptr) << "Error loading '" << name From ca3bf8ffabe10f9e1b49836564cb95ec4ea29194 Mon Sep 17 00:00:00 2001 From: samskalicky Date: Tue, 31 Dec 2019 01:21:19 +0000 Subject: [PATCH 3/4] fixed whitespace --- src/c_api/c_api.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3bc9656e1ce5..ca39ef22b20a 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -164,14 +164,14 @@ int MXLoadLib(const char *path) { mutateInputs_t mutate_fp = nullptr; createOpState_t create_opstate_fp = nullptr; bool isSubgraphOp = false; - + // get custom operator implemenation from the dynamic library opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp, &mutate_fp, &create_opstate_fp, &isSubgraphOp); CHECK(parse_fp != nullptr) << "Error loading '" << name << "' custom op, ParseAttrs function was not set."; - if(!isSubgraphOp) { + if (!isSubgraphOp) { // validate custom operator functions from the dynamic library CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name << "' custom op, Forward or CreateOpState function was not set."; @@ -182,9 +182,9 @@ int MXLoadLib(const char *path) { } else { // validate custom operator functions from the dynamic library CHECK(create_opstate_fp != nullptr) << "Error loading '" << name - << "' custom subgraph op, CreateOpState function was not set."; + << "' custom subgraph op, CreateOpState function was not set."; } - + LOG(INFO) << "\tOp[" << i << "] " << name; std::string name_str(name); @@ -654,7 +654,7 @@ int MXLoadLib(const char *path) { // TODO(samskalicky): enable constant overwriting of registertion multiple times plevel++; } - if(!isSubgraphOp) { + if (!isSubgraphOp) { regOp.set_attr("FInferType", infer_type, plevel); regOp.set_attr("FInferShape", infer_shape, plevel); regOp.set_attr("FInferStorageType", infer_storage_type, plevel); @@ -675,7 +675,7 @@ int MXLoadLib(const char *path) { regOp.set_attr("FMutateInputs", DefaultSubgraphOpMutableInputs, plevel); } - + // optionally add stateful forward if (create_opstate_fp != nullptr) { regOp.set_attr("FCreateOpState", create_opstate, plevel); From bb1890925b48fd34d6fc3e4fcbddb00a8d099d53 Mon Sep 17 00:00:00 2001 From: samskalicky Date: Tue, 31 Dec 2019 03:57:58 +0000 Subject: [PATCH 4/4] simplified API --- example/extensions/lib_custom_op/subgraph_lib.cc | 2 +- include/mxnet/lib_api.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example/extensions/lib_custom_op/subgraph_lib.cc b/example/extensions/lib_custom_op/subgraph_lib.cc index c5d230894d3f..27da0fd9a324 100644 --- a/example/extensions/lib_custom_op/subgraph_lib.cc +++ b/example/extensions/lib_custom_op/subgraph_lib.cc @@ -84,7 +84,7 @@ MXReturnValue createOpState(std::map attrs, REGISTER_OP(_custom_subgraph_op) .setParseAttrs(parseAttrs) -.isSubgraphOp(true) +.setIsSubgraphOp() .setCreateOpState(createOpState); MXReturnValue initialize(int version) { diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 9808d76d9a7c..f21e484216ea 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -572,8 +572,8 @@ class CustomOp { create_opstate = func; return *this; } - CustomOp& isSubgraphOp(bool state) { - isSGop = state; + CustomOp& setIsSubgraphOp() { + isSGop = true; return *this; }