Skip to content

Commit

Permalink
Support extra inputs for subgraph ops (apache#18779)
Browse files Browse the repository at this point in the history
Support additional inputs to custom subgraph ops that are not direct dependencies to ops in the subgraph. This will enable various use cases: custom control flow ops, custom ops that maintain a state that should be saved/loaded, etc.

Highlights:

* Added test that uses a graph pass (addInputPass) to add a new custom input to the subgraph op

* Added new optional argument (clear) to hybridize & optimize_for APIs in Gluon Block to enable multiple optimizations

* refactored lib_api.h JSON utilities

* added new Graph data structure utilities to simplify custom graph passes

* refactored custom op registration

* enhanced custom subgraph op to support additional inputs to subgraph op that is not an input to ops in the subgraph

* updated subgraph & graph pass READMEs

* Added error messaging from external library
  • Loading branch information
samskalicky authored Aug 14, 2020
1 parent 86e96dc commit daf8b43
Show file tree
Hide file tree
Showing 16 changed files with 1,496 additions and 844 deletions.
4 changes: 3 additions & 1 deletion example/extensions/lib_api/init_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

MXReturnValue initialize(int version) {
if (version >= 10700) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
16 changes: 8 additions & 8 deletions example/extensions/lib_custom_op/gemm_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include "lib_api.h"

using namespace mxnet::ext;

// main matrix multiplication routine
void gemm(const float* A, const float* B, float* C,
const unsigned n, const unsigned k, const unsigned m) {
Expand Down Expand Up @@ -128,12 +130,12 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int> *outtypes) {
// validate inputs
if (intypes->size() != 2) {
std::cout << "Expected 2 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 2 inputs to inferType";
return MX_FAIL;
}
for (unsigned i = 0; i < intypes->size(); i++) {
if (intypes->at(i) != kFloat32) {
std::cout << "Expected input " << i << " to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input " << i << " to have float32 type";
return MX_FAIL;
}
}
Expand All @@ -147,11 +149,11 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 2) {
std::cout << "Expected 2 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 2 inputs to inferShape";
return MX_FAIL;
}
if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) {
std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 2D matrices for both inputs to inferShape";
return MX_FAIL;
}

Expand All @@ -160,7 +162,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
unsigned kk = inshapes->at(1)[0];
unsigned m = inshapes->at(1)[1];
if (k != kk) {
std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl;
MX_ERROR_MSG << "Exected first input axis 1 equals to second input axis 0";
return MX_FAIL;
}

Expand Down Expand Up @@ -196,8 +198,6 @@ class MyStatefulGemm : public CustomStatefulOp {
return backward(attrs_, inputs, outputs, op_res);
}

~MyStatefulGemm() = default;

private:
int count;
const std::unordered_map<std::string, std::string> attrs_;
Expand Down Expand Up @@ -231,7 +231,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
4 changes: 3 additions & 1 deletion example/extensions/lib_custom_op/relu_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
Expand Down Expand Up @@ -263,7 +265,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
22 changes: 12 additions & 10 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
Expand Down Expand Up @@ -71,11 +73,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
// The data types and storage types of inputs and outputs should be the same.
if(inputs->at(0).dtype != outputs->at(0).dtype ||
inputs->at(0).stype != outputs->at(0).stype) {
std::cout << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype << std::endl;
MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype;
return MX_FAIL;
}

Expand All @@ -102,11 +104,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int>* outtypes) {
// validate inputs
if (intypes->size() != 1) {
std::cout << "Expected 1 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferType";
return MX_FAIL;
}
if (intypes->at(0) != kFloat32) {
std::cout << "Expected input to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input to have float32 type";
return MX_FAIL;
}

Expand All @@ -118,7 +120,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
std::vector<int>* instypes,
std::vector<int>* outstypes) {
if (instypes->at(0) != kCSRStorage) {
std::cout << "Expected storage type is kCSRStorage" << std::endl;
MX_ERROR_MSG << "Expected storage type is kCSRStorage";
return MX_FAIL;
}
outstypes->at(0) = instypes->at(0);
Expand All @@ -130,7 +132,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 1) {
std::cout << "Expected 1 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferShape";
return MX_FAIL;
}

Expand Down Expand Up @@ -195,7 +197,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
22 changes: 12 additions & 10 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
Expand Down Expand Up @@ -74,11 +76,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
// The data types and storage types of inputs and outputs should be the same.
if(inputs->at(0).dtype != outputs->at(0).dtype ||
inputs->at(0).stype != outputs->at(0).stype) {
std::cout << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype << std::endl;
MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype;
return MX_FAIL;
}
transpose(inputs->at(0), outputs->at(0), res);
Expand All @@ -104,11 +106,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int>* outtypes) {
// validate inputs
if (intypes->size() != 1) {
std::cout << "Expected 1 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferType";
return MX_FAIL;
}
if (intypes->at(0) != kFloat32) {
std::cout << "Expected input to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input to have float32 type";
return MX_FAIL;
}

Expand All @@ -120,7 +122,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
std::vector<int>* instypes,
std::vector<int>* outstypes) {
if (instypes->at(0) != kRowSparseStorage) {
std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
MX_ERROR_MSG << "Expected storage type is kRowSparseStorage";
return MX_FAIL;
}
outstypes->at(0) = instypes->at(0);
Expand All @@ -132,7 +134,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 1) {
std::cout << "Expected 1 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferShape";
return MX_FAIL;
}

Expand Down Expand Up @@ -197,7 +199,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
Loading

0 comments on commit daf8b43

Please sign in to comment.