Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Init commit for Gluon HybridBlock partition (sample test included)
guanxinq committed Jan 17, 2020
1 parent 2c50232 commit 3a64283
Showing 4 changed files with 89 additions and 1 deletion.
15 changes: 15 additions & 0 deletions include/mxnet/c_api_test.h
@@ -53,11 +53,26 @@ MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
const uint32_t num_ops,
const char** op_names);

/*!
* \brief Given a subgraph property name, use the provided op names
* as the op_names attribute for that subgraph property, instead of
* the predefined one. This is only for the purpose of testing.
* Compared to MXSetSubgraphPropertyOpNames(), this API will add
* op_names to the backend property.
*/
MXNET_DLL int MXSetSubgraphPropertyOpNamesV2(const char* prop_name,
const uint32_t num_ops,
const char** op_names);
/*!
* \brief Given a subgraph property name, delete the op name set
* in the SubgraphPropertyOpNameSet.
*/
MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
/*!
* \brief Given a subgraph property name, remove the op_names attribute
* of the subgraph properties in the SubgraphBackend.
*/
MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name);

#ifdef __cplusplus
}
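
For orientation, here is a minimal sketch (illustration only, not part of this diff) of how the two V2 hooks declared above could be driven from Python through the ctypes helpers in mxnet.base, mirroring the pattern the existing tests use for MXSetSubgraphPropertyOpNames. The backend name 'default' and the op list are assumptions:

from mxnet.base import _LIB, check_call, c_str, c_str_array, mx_uint

# Hypothetical op list; attach it as the op_names attribute of every property
# registered under the 'default' backend.
op_names = ['FullyConnected', 'Activation']
check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(
    c_str('default'), mx_uint(len(op_names)), c_str_array(op_names)))
# ... run partitioned inference here ...
# Strip the attribute again so later runs fall back to the predefined op set.
check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str('default')))
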
7 changes: 6 additions & 1 deletion python/mxnet/gluon/block.py
@@ -932,6 +932,9 @@ def _get_graph(self, *args):

def _build_cache(self, *args):
data, out = self._get_graph(*args)
if self._backend:
# TODO: pass in all arguments
out = out.optimize_for(self._backend)
data_names = {data.name: i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()
@@ -1040,7 +1043,9 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, **kwargs):
def hybridize(self, active=True, backend='', optargs={}, **kwargs):
self._backend = backend
self._optargs = optargs
self._active = active
self._flags = list(kwargs.items())
self._clear_cached_op()
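
Taken together, the hunks above let a user request backend partitioning directly from Gluon: hybridize() records the backend name, and _build_cache() applies it via optimize_for() when the cached graph is first built. A minimal usage sketch (illustration only; it assumes 'default' is a registered subgraph backend, as in the tests below):

from mxnet import nd
from mxnet.gluon import nn

net = nn.Dense(2)
net.initialize()
net.hybridize(backend='default')  # store the backend name on the block
y = net(nd.ones((1, 4)))          # first forward builds the cache and calls optimize_for('default')
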
28 changes: 28 additions & 0 deletions src/c_api/c_api_test.cc
@@ -73,8 +73,36 @@ int MXSetSubgraphPropertyOpNames(const char* prop_name,
API_END();
}

int MXSetSubgraphPropertyOpNamesV2(const char* prop_name,
const uint32_t num_ops,
const char** op_names) {
API_BEGIN();
std::unordered_set<std::string> op_name_set;
for (size_t i = 0; i < num_ops; ++i) {
op_name_set.emplace(op_names[i]);
}
auto& backend =
mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto& property : subgraph_prop_list) {
property->SetAttr("op_names", op_name_set);
}
API_END();
}

int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
API_BEGIN();
mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
API_END();
}

int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) {
API_BEGIN();
auto& backend =
mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto& property : subgraph_prop_list) {
property->RemoveAttr("op_names");
}
API_END();
}
40 changes: 40 additions & 0 deletions tests/python/unittest/test_subgraph_op.py
@@ -22,6 +22,8 @@
from mxnet.symbol import Symbol
import numpy as np
from mxnet.test_utils import assert_almost_equal
from mxnet.gluon import nn
from mxnet import nd


def _test_subgraph_exe(subgraph_backend):
@@ -364,6 +366,44 @@ def test_subgraph_exe():
def test_subgraph_v2_exe():
_test_subgraph_exe('default_v2')

# This is just a temporary sample test.
# TODO: refactor tests and add more tests for Gluon.
class Net(nn.HybridBlock):
def __init__(self, **kwargs):
super(Net, self).__init__(**kwargs)
with self.name_scope():
self.fc1 = nn.Dense(256)
self.fc2 = nn.Dense(128)
self.fc3 = nn.Dense(2)

def hybrid_forward(self, F, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)

def test_subgraph_gluon():
# Create the network and initialize it.
net = Net()
net.initialize()
x = nd.random.normal(shape=(1, 512))

# Call hybridize and run inference.
net.hybridize()
outputs1 = net(x)

# Call hybridize with default backend and run inference.
net.hybridize(backend="default")
op_names = []
check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str("default"), mx_uint(len(op_names)),
c_str_array(op_names)))
outputs2 = net(x)
check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str("default")))

# Compare results.
assert len(outputs1) == len(outputs2)
for i in range(len(outputs1)):
assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

if __name__ == '__main__':
import nose
nose.runmodule()
