diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 1d42cf7c18f8..38195bd62ffa 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -127,7 +127,7 @@ def _cut_subgraph(subg):
 # This construct a subgraph for given output nodes.
 # If an output node is one of the input nodes, we call identity to make sure
 # that outputs nodes are different from input nodes.
-def _construct_subgraph(sym_out, sym_states):
+def _construct_subgraph(sym_out, sym_states, name):
     sym_out = _as_list(sym_out)
     sym_states = _as_list(sym_states)
     all_outputs = []
@@ -137,18 +137,16 @@ def _construct_subgraph(sym_out, sym_states):
 
     flat_out = []
     all_input_names = g.list_inputs()
-    output_names = [o.name for o in sym_out]
+    output_names = {o.name for o in sym_out}
     for o in sym_out:
-        if o.name in all_input_names:
+        if o.name in all_input_names or o.list_attr().get("__subgraph_name__", "") != name:
             flat_out.append(symbol.op.identity(o))
         else:
             flat_out.append(o)
 
     for s in sym_states:
-        if s.name in all_input_names or s.name in output_names:
-            # There is a problem if the outputs are the same as the inputs
-            # or the first output. By calling identity, we can make sure that
-            # all symbols will refer to different NDArrays.
+        if s.name in all_input_names or s.name in output_names or \
+                s.list_attr().get("__subgraph_name__", "") != name:
             flat_out.append(symbol.op.identity(s))
         else:
             flat_out.append(s)
@@ -256,7 +254,7 @@ def check_data(inputs, in_type, msg):
     num_out_data = len(sym_out)
     num_states = len(sym_states)
     num_outputs = num_out_data + num_states
-    g = _construct_subgraph(sym_out, sym_states)
+    g = _construct_subgraph(sym_out, sym_states, name)
 
     input_syms = _get_graph_inputs(g)
     cut_syms = _cut_subgraph(g)
@@ -469,9 +467,12 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
             num_outputs = len(outputs) + len(final_state)
             # nnvm cut-graph does not allow inputs and outputs overlap
             # so we calculate the name of inputs, and copy outputs once it overlaps with inputs
-            all_input_names = symbol.Group(outputs + final_state).list_inputs()
-            make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
             # group all outputs of graph_func
+            all_input_names = symbol.Group(outputs + final_state).list_inputs()
+            in_input = lambda x: x.name in all_input_names
+            in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+            make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+                            else x
             graph = symbol.Group(list(map(make_identity, outputs + final_state)))
         return graph, num_out_data, num_outputs
 
@@ -627,9 +628,12 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
             num_outputs = len(outputs)
             # nnvm cut-graph does not allow inputs and outputs overlap
             # so we calculate the name of inputs, and copy outputs once it overlaps with inputs
-            all_input_names = symbol.Group(outputs).list_inputs()
-            make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
             # group all outputs of graph_func
+            all_input_names = symbol.Group(outputs).list_inputs()
+            in_input = lambda x: x.name in all_input_names
+            in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+            make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+                            else x
             graph = symbol.Group(list(map(make_identity, outputs)))
         return graph, num_outputs
 
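The heart of the Python-side change is the rule for when a subgraph output must be copied through identity before nnvm's cut-graph runs. A symbol created under AttrScope(__subgraph_name__=...) carries that name in its attributes, so a mismatched (or missing) name means the symbol was captured from an enclosing graph, e.g. the output of a previous foreach. A standalone sketch of that rule follows; needs_copy and make_identity are illustrative helper names, while all_input_names and subgraph_name stand for the values computed inline in _construct_subgraph/_create_subgraph above:

    from mxnet import symbol

    def needs_copy(sym, all_input_names, subgraph_name):
        # Copy when the symbol is literally one of the subgraph inputs, or
        # when it was not created under this subgraph's AttrScope, i.e. it
        # is captured from an enclosing graph. Either way, cut-graph needs
        # an output node that is distinct from the inputs.
        in_input = sym.name in all_input_names
        in_graph = sym.list_attr().get("__subgraph_name__", "") == subgraph_name
        return in_input or not in_graph

    def make_identity(sym, all_input_names, subgraph_name):
        if needs_copy(sym, all_input_names, subgraph_name):
            return symbol.op.identity(sym)
        return sym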
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index c27a59a67c6e..35ecec7e11f6 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -372,13 +372,13 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
   // a subgraph.
   API_BEGIN();
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
-  std::string subg_attr = "__subgraph_name__";
+  const std::string subg_attr = "__subgraph_name__";
   auto out_node = s->outputs[0].node;
   auto it = out_node->attrs.dict.find(subg_attr);
   if (it != out_node->attrs.dict.end()) {
-    std::string subg_name = it->second;
+    const std::string &subg_name = it->second;
     std::vector<nnvm::NodeEntry*> input_entries;
-    DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+    DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
              (nnvm::NodePtr n) {
       // If the node itself isn't in the subgraph, we ignore it.
       auto it = n->attrs.dict.find(subg_attr);
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 7c1beccb0288..d6b6703ddd58 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -1225,6 +1225,9 @@ static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
     CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
     return ret;
   };
+  for (const dim_t &cond_in : params.cond_input_locs) {
+    (*out_attrs)[cond_in] = kDefaultStorage;
+  }
   bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
   bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
   return succ_0 && succ_1;
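On the C++ side, MXSymbolCutSubgraph walks the graph with DFSVisit and collects every edge that crosses from a node tagged with the subgraph's __subgraph_name__ to a producer outside it; those crossing edges become the inputs of the cut subgraph. A toy model of that traversal in plain Python, where Node and cut_subgraph are illustrative stand-ins rather than the real nnvm API:

    from dataclasses import dataclass, field

    @dataclass
    class Node:
        name: str
        attrs: dict = field(default_factory=dict)
        inputs: list = field(default_factory=list)

    def cut_subgraph(outputs, subg_name):
        cut_entries, visited = [], set()
        def dfs(node):
            if id(node) in visited:
                return
            visited.add(id(node))
            # If the node itself isn't in the subgraph, we ignore it.
            if node.attrs.get("__subgraph_name__") != subg_name:
                return
            for src in node.inputs:
                if src.attrs.get("__subgraph_name__") != subg_name:
                    cut_entries.append(src)  # crossing edge -> new input
                else:
                    dfs(src)
        for out in outputs:
            dfs(out)
        return cut_entries

    # The edge inner -> outer crosses the boundary of "loop0", so `outer`
    # is reported as an input of the cut subgraph.
    outer = Node("outer")
    inner = Node("inner", {"__subgraph_name__": "loop0"}, [outer])
    assert cut_subgraph([inner], "loop0") == [outer]

The control_flow.cc hunk is a separate fix in the same area: BackwardCondStorageType now pins the storage type of every condition input to kDefaultStorage before running the then/else sub-passes, presumably because inputs consumed only by the condition subgraph, which is not differentiated, receive no gradient from either branch and their storage entries would otherwise be left unset.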
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index f1188b53d814..a4b794c95951 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -1664,6 +1664,107 @@ def test_foreach_rnn():
         check_foreach_rnn(cell_type, num_states)
 
 
+@with_seed()
+def test_cut_subgraph_foreach():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, inputs, states):
+            def step1(data, states):
+                return data + 1, states
+            out1, states1 = F.contrib.foreach(step1, inputs, states)
+            out2, states2 = F.contrib.foreach(step1, out1, states)
+            def step2(data, states):
+                return data + states[0], states1
+            out, states = F.contrib.foreach(step2, out2, states)
+            return out
+
+    data = mx.nd.normal(loc=0, scale=1, shape=(5, 10))
+    states = mx.nd.normal(loc=0, scale=1, shape=(10))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data, [states])
+
+    with mx.autograd.record():
+        res1 = layer(data, [states])
+
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data, [states])
+
+    with mx.autograd.record():
+        res2 = layer(data, [states])
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_while_loop():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            out1, data1 = F.contrib.while_loop(
+                cond=lambda i: i <= 5,
+                func=lambda i: (None, (i + 1, )),
+                loop_vars=(data, ),
+                max_iterations=10,
+            )
+            out2, data2 = F.contrib.while_loop(
+                cond=lambda i: data1[0],
+                func=lambda i: (None, (i + 1, )),
+                loop_vars=data1[0],
+                max_iterations=10,
+            )
+            return data2[0]
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data)
+    with mx.autograd.record():
+        res1 = layer(data)
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data)
+    with mx.autograd.record():
+        res2 = layer(data)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_cond():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            (data1, ) = F.contrib.cond(
+                data > 0.5,
+                then_func=lambda: data * 2,
+                else_func=lambda: data * 3,
+            )
+            (data2, ) = F.contrib.cond(
+                data1 > 0.5,
+                then_func=lambda: data1 * 2,
+                else_func=lambda: data1 * 3,
+            )
+            return data2
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data)
+    with mx.autograd.record():
+        res1 = layer(data)
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data)
+    with mx.autograd.record():
+        res2 = layer(data)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
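The three new tests share one pattern: build a block that chains two or three control-flow ops, run it once imperatively and once hybridized, with and without autograd recording, and check that both paths agree via assert_almost_equal; chaining is what exercises the identity-copy logic added in contrib.py above, since each later op consumes outputs produced inside an earlier subgraph. To run just these tests with the nose runner the file already uses, something like nosetests -v tests/python/unittest/test_contrib_control_flow.py:test_cut_subgraph_foreach (and likewise for test_cut_subgraph_while_loop and test_cut_subgraph_cond) should work from the repository root.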