Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Enhancements for custom subgraph op (#17194)
Browse files Browse the repository at this point in the history
* initial commit

* added flag on user library side in example

* fixed whitespace

* simplified API
  • Loading branch information
samskalicky authored and zachgk committed Jan 2, 2020
1 parent b38816d commit 77c7c3a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 33 deletions.
17 changes: 1 addition & 16 deletions example/extensions/lib_custom_op/subgraph_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,6 @@ MXReturnValue parseAttrs(std::map<std::string, std::string> attrs,
return MX_SUCCESS;
}

MXReturnValue inferType(std::map<std::string, std::string> attrs,
std::vector<int> &intypes,
std::vector<int> &outtypes) {
outtypes[0] = intypes[0];
return MX_SUCCESS;
}

MXReturnValue inferShape(std::map<std::string, std::string> attrs,
std::vector<std::vector<unsigned int>> &inshapes,
std::vector<std::vector<unsigned int>> &outshapes) {
outshapes[0] = inshapes[0];
return MX_SUCCESS;
}

class MyStatefulOp : public CustomStatefulOp {
public:
explicit MyStatefulOp(std::string sym) : subgraph_sym(sym) {}
Expand Down Expand Up @@ -98,8 +84,7 @@ MXReturnValue createOpState(std::map<std::string, std::string> attrs,

REGISTER_OP(_custom_subgraph_op)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setIsSubgraphOp()
.setCreateOpState(createOpState);

MXReturnValue initialize(int version) {
Expand Down
13 changes: 10 additions & 3 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ class CustomOp {
public:
explicit CustomOp(const char* op_name) : name(op_name),
forward(NULL), backward(NULL), parse_attrs(NULL), infer_type(NULL),
infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL) {}
infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL),
isSGop(false) {}
~CustomOp() {}
CustomOp& setForward(fcomp_t fcomp) {
forward = fcomp;
Expand Down Expand Up @@ -571,6 +572,10 @@ class CustomOp {
create_opstate = func;
return *this;
}
CustomOp& setIsSubgraphOp() {
isSGop = true;
return *this;
}

/*! \brief operator name */
const char* name;
Expand All @@ -582,6 +587,7 @@ class CustomOp {
inferShape_t infer_shape;
mutateInputs_t mutate_inputs;
createOpState_t create_opstate;
bool isSGop;
};

/*!
Expand Down Expand Up @@ -658,7 +664,7 @@ typedef int (*opRegSize_t)(void);
typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*,
parseAttrs_t*, inferType_t*,
inferShape_t*, mutateInputs_t*,
createOpState_t*);
createOpState_t*, bool*);

#define MXLIB_OPCALLFREE_STR "_opCallFree"
typedef int (*opCallFree_t)(void*);
Expand Down Expand Up @@ -737,7 +743,7 @@ extern "C" {
_opRegGet(int idx, const char** name, fcomp_t* fcomp, fcomp_t* fgrad,
parseAttrs_t* parse, inferType_t* type,
inferShape_t* shape, mutateInputs_t* mutate,
createOpState_t* create_op) {
createOpState_t* create_op, bool *isSGop) {
CustomOp op = Registry<CustomOp>::get()->get(idx);
*name = op.name;
*fcomp = op.forward;
Expand All @@ -747,6 +753,7 @@ extern "C" {
*shape = op.infer_shape;
*mutate = op.mutate_inputs;
*create_op = op.create_opstate;
*isSGop = op.isSGop;
}

/*! \brief calls free from the external library for library allocated arrays */
Expand Down
51 changes: 37 additions & 14 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"
#include "../operator/tensor/matrix_op-inl.h"
#include "../operator/tvmop/op_module.h"
#include "../common/utils.h"
Expand Down Expand Up @@ -162,20 +163,27 @@ int MXLoadLib(const char *path) {
fcomp_t fgrad_fp = nullptr;
mutateInputs_t mutate_fp = nullptr;
createOpState_t create_opstate_fp = nullptr;
bool isSubgraphOp = false;

// get custom operator implemenation from the dynamic library
opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp,
&mutate_fp, &create_opstate_fp);
&mutate_fp, &create_opstate_fp, &isSubgraphOp);

// validate custom operator functions from the dynamic library
CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name
<< "' custom op, Forward or CreateOpState function was not set.";
CHECK(parse_fp != nullptr) << "Error loading '" << name
<< "' custom op, ParseAttrs function was not set.";
CHECK(type_fp != nullptr) << "Error loading '" << name
<< "' custom op, ParseAttrs function was not set.";
if (!isSubgraphOp) {
// validate custom operator functions from the dynamic library
CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name
<< "' custom op, Forward or CreateOpState function was not set.";
CHECK(type_fp != nullptr) << "Error loading '" << name
<< "' custom op, InferType function was not set.";
CHECK(shape_fp != nullptr) << "Error loading '" << name
CHECK(shape_fp != nullptr) << "Error loading '" << name
<< "' custom op, InferShape function was not set.";
} else {
// validate custom operator functions from the dynamic library
CHECK(create_opstate_fp != nullptr) << "Error loading '" << name
<< "' custom subgraph op, CreateOpState function was not set.";
}

LOG(INFO) << "\tOp[" << i << "] " << name;
std::string name_str(name);
Expand Down Expand Up @@ -646,10 +654,28 @@ int MXLoadLib(const char *path) {
// TODO(samskalicky): enable constant overwriting of registertion multiple times
plevel++;
}
regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
if (!isSubgraphOp) {
regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
// optionally add fmutate inputs if user specified a function
if (mutate_fp != nullptr)
regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
} else {
using namespace mxnet::op;
regOp.set_attr<nnvm::FInferType>("FInferType",
DefaultSubgraphOpType, plevel);
regOp.set_attr<mxnet::FInferShape>("FInferShape",
DefaultSubgraphOpShape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType",
DefaultSubgraphOpStorageType, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest",
DefaultSubgraphOpResourceRequest, plevel);
regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
DefaultSubgraphOpMutableInputs, plevel);
}

// optionally add stateful forward
if (create_opstate_fp != nullptr) {
regOp.set_attr<FCreateOpState>("FCreateOpState", create_opstate, plevel);
Expand All @@ -658,9 +684,6 @@ int MXLoadLib(const char *path) {
} else {
regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_lambda, plevel);
}
// optionally add fmutate inputs if user specified a function
if (mutate_fp != nullptr)
regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
// optionally add fgradient if user specified a function
if (fgrad_fp != nullptr || create_opstate_fp != nullptr) {
regOp.set_attr<nnvm::FGradient>("FGradient", grad_reg, plevel);
Expand Down

0 comments on commit 77c7c3a

Please sign in to comment.