Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhancements for custom subgraph op #17194

Merged
merged 5 commits into from
Jan 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions example/extensions/lib_custom_op/subgraph_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,6 @@ MXReturnValue parseAttrs(std::map<std::string, std::string> attrs,
return MX_SUCCESS;
}

/*!
 * \brief Type inference for the custom subgraph op.
 * The single output simply adopts the dtype of the first input.
 * \param attrs    operator attributes from the symbol (unused here)
 * \param intypes  dtypes of the inputs
 * \param outtypes dtypes of the outputs, written in place
 * \return MX_SUCCESS always
 */
MXReturnValue inferType(std::map<std::string, std::string> attrs,
                        std::vector<int> &intypes,
                        std::vector<int> &outtypes) {
  // Propagate the first input's dtype straight through to the output.
  outtypes.front() = intypes.front();
  return MX_SUCCESS;
}

/*!
 * \brief Shape inference for the custom subgraph op.
 * The single output simply adopts the shape of the first input.
 * \param attrs     operator attributes from the symbol (unused here)
 * \param inshapes  shapes of the inputs
 * \param outshapes shapes of the outputs, written in place
 * \return MX_SUCCESS always
 */
MXReturnValue inferShape(std::map<std::string, std::string> attrs,
                         std::vector<std::vector<unsigned int>> &inshapes,
                         std::vector<std::vector<unsigned int>> &outshapes) {
  // Propagate the first input's shape straight through to the output.
  outshapes.front() = inshapes.front();
  return MX_SUCCESS;
}

class MyStatefulOp : public CustomStatefulOp {
public:
explicit MyStatefulOp(std::string sym) : subgraph_sym(sym) {}
Expand Down Expand Up @@ -98,8 +84,7 @@ MXReturnValue createOpState(std::map<std::string, std::string> attrs,

REGISTER_OP(_custom_subgraph_op)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setIsSubgraphOp()
.setCreateOpState(createOpState);

MXReturnValue initialize(int version) {
Expand Down
13 changes: 10 additions & 3 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ class CustomOp {
public:
explicit CustomOp(const char* op_name) : name(op_name),
forward(NULL), backward(NULL), parse_attrs(NULL), infer_type(NULL),
infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL) {}
infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL),
isSGop(false) {}
~CustomOp() {}
CustomOp& setForward(fcomp_t fcomp) {
forward = fcomp;
Expand Down Expand Up @@ -571,6 +572,10 @@ class CustomOp {
create_opstate = func;
return *this;
}
/*! \brief marks this operator as a subgraph op, so MXNet registers the
 *  default subgraph infer-type/shape/storage functions for it instead of
 *  requiring user-provided ones; returns *this for registration chaining */
CustomOp& setIsSubgraphOp() {
isSGop = true;
return *this;
}

/*! \brief operator name */
const char* name;
Expand All @@ -582,6 +587,7 @@ class CustomOp {
inferShape_t infer_shape;
mutateInputs_t mutate_inputs;
createOpState_t create_opstate;
bool isSGop;
};

/*!
Expand Down Expand Up @@ -658,7 +664,7 @@ typedef int (*opRegSize_t)(void);
typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*,
parseAttrs_t*, inferType_t*,
inferShape_t*, mutateInputs_t*,
createOpState_t*);
createOpState_t*, bool*);

#define MXLIB_OPCALLFREE_STR "_opCallFree"
typedef int (*opCallFree_t)(void*);
Expand Down Expand Up @@ -737,7 +743,7 @@ extern "C" {
_opRegGet(int idx, const char** name, fcomp_t* fcomp, fcomp_t* fgrad,
parseAttrs_t* parse, inferType_t* type,
inferShape_t* shape, mutateInputs_t* mutate,
createOpState_t* create_op) {
createOpState_t* create_op, bool *isSGop) {
CustomOp op = Registry<CustomOp>::get()->get(idx);
*name = op.name;
*fcomp = op.forward;
Expand All @@ -747,6 +753,7 @@ extern "C" {
*shape = op.infer_shape;
*mutate = op.mutate_inputs;
*create_op = op.create_opstate;
*isSGop = op.isSGop;
}

/*! \brief calls free from the external library for library allocated arrays */
Expand Down
51 changes: 37 additions & 14 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"
#include "../operator/tensor/matrix_op-inl.h"
#include "../operator/tvmop/op_module.h"
#include "../common/utils.h"
Expand Down Expand Up @@ -162,20 +163,27 @@ int MXLoadLib(const char *path) {
fcomp_t fgrad_fp = nullptr;
mutateInputs_t mutate_fp = nullptr;
createOpState_t create_opstate_fp = nullptr;
bool isSubgraphOp = false;

// get custom operator implementation from the dynamic library
opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp,
&mutate_fp, &create_opstate_fp);
&mutate_fp, &create_opstate_fp, &isSubgraphOp);

// validate custom operator functions from the dynamic library
CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name
<< "' custom op, Forward or CreateOpState function was not set.";
CHECK(parse_fp != nullptr) << "Error loading '" << name
<< "' custom op, ParseAttrs function was not set.";
CHECK(type_fp != nullptr) << "Error loading '" << name
<< "' custom op, ParseAttrs function was not set.";
if (!isSubgraphOp) {
// validate custom operator functions from the dynamic library
CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name
<< "' custom op, Forward or CreateOpState function was not set.";
CHECK(type_fp != nullptr) << "Error loading '" << name
<< "' custom op, InferType function was not set.";
CHECK(shape_fp != nullptr) << "Error loading '" << name
CHECK(shape_fp != nullptr) << "Error loading '" << name
<< "' custom op, InferShape function was not set.";
} else {
// validate custom operator functions from the dynamic library
CHECK(create_opstate_fp != nullptr) << "Error loading '" << name
<< "' custom subgraph op, CreateOpState function was not set.";
}

LOG(INFO) << "\tOp[" << i << "] " << name;
std::string name_str(name);
Expand Down Expand Up @@ -646,10 +654,28 @@ int MXLoadLib(const char *path) {
// TODO(samskalicky): enable constant overwriting of registration multiple times
plevel++;
}
regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
if (!isSubgraphOp) {
regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
// optionally add fmutate inputs if user specified a function
if (mutate_fp != nullptr)
regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
} else {
using namespace mxnet::op;
regOp.set_attr<nnvm::FInferType>("FInferType",
DefaultSubgraphOpType, plevel);
regOp.set_attr<mxnet::FInferShape>("FInferShape",
DefaultSubgraphOpShape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType",
DefaultSubgraphOpStorageType, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest",
DefaultSubgraphOpResourceRequest, plevel);
regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check if (mutate_fp != nullptr) here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed offline, we'll only use default functions when the user sets the setIsSubgraphOp flag. Later we can revisit if the user wants to use custom functions mixed with default functions

DefaultSubgraphOpMutableInputs, plevel);
}

// optionally add stateful forward
if (create_opstate_fp != nullptr) {
regOp.set_attr<FCreateOpState>("FCreateOpState", create_opstate, plevel);
Expand All @@ -658,9 +684,6 @@ int MXLoadLib(const char *path) {
} else {
regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_lambda, plevel);
}
// optionally add fmutate inputs if user specified a function
if (mutate_fp != nullptr)
regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
// optionally add fgradient if user specified a function
if (fgrad_fp != nullptr || create_opstate_fp != nullptr) {
regOp.set_attr<nnvm::FGradient>("FGradient", grad_reg, plevel);
Expand Down