From 3a64283e2bb6a11a12437c654ab2e9c073a5b9e7 Mon Sep 17 00:00:00 2001 From: guanxinq Date: Fri, 17 Jan 2020 18:20:02 +0000 Subject: [PATCH] Init commit for Gluon hybridblocks partition(sample test included) --- include/mxnet/c_api_test.h | 15 +++++++++ python/mxnet/gluon/block.py | 7 +++- src/c_api/c_api_test.cc | 28 ++++++++++++++++ tests/python/unittest/test_subgraph_op.py | 40 +++++++++++++++++++++++ 4 files changed, 89 insertions(+), 1 deletion(-) diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h index 8d397d42e5c1..35856e10efa3 100644 --- a/include/mxnet/c_api_test.h +++ b/include/mxnet/c_api_test.h @@ -53,11 +53,26 @@ MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name, const uint32_t num_ops, const char** op_names); +/*! + * \brief Given a subgraph property name, use the provided op names + * as the op_names attribute for that subgraph property, instead of + * the predefined one. This is only for the purpose of testing. + * Compared to MXSetSubgraphPropertyOpNames(), this API will add + * op_names to the backend property. + */ +MXNET_DLL int MXSetSubgraphPropertyOpNamesV2(const char* prop_name, + const uint32_t num_ops, + const char** op_names); /*! * \brief Given a subgraph property name, delete the op name set * in the SubgraphPropertyOpNameSet. */ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name); +/*! + * \brief Given a subgraph property name, remove op_names attribute of + * the in the SubgraphBackend property. + */ +MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name); #ifdef __cplusplus } diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 33679220878a..0becefe82b54 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -932,6 +932,9 @@ def _get_graph(self, *args): def _build_cache(self, *args): data, out = self._get_graph(*args) + if self._backend: + # To do: pass in all arguments + out = out.optimize_for(self._backend) data_names = {data.name: i for i, data in enumerate(data)} params = self.collect_params() input_names = out.list_inputs() @@ -1040,7 +1043,9 @@ def register_child(self, block, name=None): super(HybridBlock, self).register_child(block, name) self._clear_cached_op() - def hybridize(self, active=True, **kwargs): + def hybridize(self, active=True, backend='', optargs={}, **kwargs): + self._backend = backend + self._optargs = optargs self._active = active self._flags = list(kwargs.items()) self._clear_cached_op() diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc index 8d4e432e4f7d..de4fb7dca18e 100644 --- a/src/c_api/c_api_test.cc +++ b/src/c_api/c_api_test.cc @@ -73,8 +73,36 @@ int MXSetSubgraphPropertyOpNames(const char* prop_name, API_END(); } +int MXSetSubgraphPropertyOpNamesV2(const char* prop_name, + const uint32_t num_ops, + const char** op_names) { + API_BEGIN(); + std::unordered_set op_name_set; + for (size_t i = 0; i < num_ops; ++i) { + op_name_set.emplace(op_names[i]); + } + auto& backend = + mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); + const auto& subgraph_prop_list = backend->GetSubgraphProperties(); + for (auto& property : subgraph_prop_list) { + property->SetAttr("op_names", op_name_set); + } + API_END(); +} + int MXRemoveSubgraphPropertyOpNames(const char* prop_name) { API_BEGIN(); mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name); API_END(); } + +int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) { + API_BEGIN(); + auto& backend = + mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); + const auto& subgraph_prop_list = backend->GetSubgraphProperties(); + for (auto& property : subgraph_prop_list) { + property->RemoveAttr("op_names"); + } + API_END(); +} diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py index f2d75658aafd..ae5faf9496d8 100644 --- a/tests/python/unittest/test_subgraph_op.py +++ b/tests/python/unittest/test_subgraph_op.py @@ -22,6 +22,8 @@ from mxnet.symbol import Symbol import numpy as np from mxnet.test_utils import assert_almost_equal +from mxnet.gluon import nn +from mxnet import nd def _test_subgraph_exe(subgraph_backend): @@ -364,6 +366,44 @@ def test_subgraph_exe(): def test_subgraph_v2_exe(): _test_subgraph_exe('default_v2') +# Here is just a temporary sample test. +# To do: refactor tests and add more test for gluon. +class Net(nn.HybridBlock): + def __init__(self, **kwargs): + super(Net, self).__init__(**kwargs) + with self.name_scope(): + self.fc1 = nn.Dense(256) + self.fc2 = nn.Dense(128) + self.fc3 = nn.Dense(2) + + def hybrid_forward(self, F, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + +def test_subgraph_gluon(): + #Create networks and initialze. + net = Net() + net.initialize() + x = nd.random.normal(shape=(1, 512)) + + # Call hybridize and run inference. + net.hybridize() + outputs1 = net(x) + + # Call hybridize with default backend and run inference. + net.hybridize(backend = "default") + op_names = [] + check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str("default"), mx_uint(len(op_names)), + c_str_array(op_names))) + outputs2 = net(x) + check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str("default"))) + + # Compare results. + assert len(outputs1) == len(outputs2) + for i in range(len(outputs1)): + assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,))) + if __name__ == '__main__': import nose nose.runmodule()