From cbdf2dbe77d43c0c5d6d23c6a96e5a5748a3aef7 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 26 Sep 2016 13:35:26 -0700 Subject: [PATCH] [SYMBOL] Change list_input->list_input_names, add list_input_variables (#59) * [SYMBOL] Change list_input->list_input_names, add list_input_variables * fix --- nnvm/include/nnvm/c_api.h | 19 +++++++++++- nnvm/python/nnvm/symbol.py | 50 ++++++++++++++++++++++++-------- nnvm/src/c_api/c_api_symbolic.cc | 19 ++++++++++++ nnvm/src/core/symbolic.cc | 7 +++-- nnvm/src/pass/gradient.cc | 17 ++++++----- nnvm/tests/python/test_graph.py | 4 +-- nnvm/tests/python/test_symbol.py | 10 ++++--- 7 files changed, 98 insertions(+), 28 deletions(-) diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index a7bf2616ddce..e4e7ac59c5f6 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -205,8 +205,25 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, int recursive_option, nn_uint *out_size, const char*** out); + +/*! + * \brief List input variables in the symbol. + * \param symbol the symbol + * \param option The option to list the inputs + * option=0 means list all arguments. + * option=1 means list arguments that are read only by the graph. + * option=2 means list arguments that are mutated by the graph. + * \param out_size output size + * \param out_sym_array the output array. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, + int option, + nn_uint *out_size, + SymbolHandle** out_sym_array); + /*! - * \brief List inputs in the symbol. + * \brief List input names in the symbol. * \param symbol the symbol * \param option The option to list the inputs * option=0 means list all arguments. 
diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index bcd34cfb6a6a..79c4b7ec4626 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -108,7 +108,7 @@ def __deepcopy__(self, _=None): def __getitem__(self, index): if isinstance(index, _base.string_types): idx = None - for i, name in enumerate(self.list_outputs()): + for i, name in enumerate(self.list_output_names()): if name == index: if idx is not None: raise ValueError('There are multiple outputs with name \"%s\"' % index) @@ -177,7 +177,40 @@ def get_internals(self): self.handle, _ctypes.byref(handle))) return Symbol(handle=handle) - def list_inputs(self, option='all'): + def _get_list_copt(self, option): + """internal function to get list option""" + if option == 'all': + return _ctypes.c_int(0) + elif option == 'read_only': + return _ctypes.c_int(1) + elif option == 'aux_state': + return _ctypes.c_int(2) + else: + raise ValueError("option need to be in {'all', 'read_only', 'aux_state'}") + + def list_input_variables(self, option='all'): + """List all the input variables in the symbol. + + Parameters + ---------- + option : {'all', 'read_only', 'aux_state'}, optional + The listing option + - 'all' will list all the arguments. + - 'read_only' lists arguments that are read by the graph. + - 'aux_state' lists arguments that are mutated by the graph as state. + Returns + ------- + vars : list of symbol + List of all the variables + """ + size = _ctypes.c_uint() + sarr = _ctypes.POINTER(_base.SymbolHandle)() + _check_call(_LIB.NNSymbolListInputVariables( + self.handle, self._get_list_copt(option), + _ctypes.byref(size), _ctypes.byref(sarr))) + return [Symbol(_base.SymbolHandle(sarr[i])) for i in range(size.value)] + + def list_input_names(self, option='all'): + """List all the inputs in the symbol. 
Parameters @@ -194,19 +227,12 @@ def list_inputs(self, option='all'): """ size = _ctypes.c_uint() sarr = _ctypes.POINTER(_ctypes.c_char_p)() - if option == 'all': - copt = _ctypes.c_int(0) - elif option == 'read_only': - copt = _ctypes.c_int(1) - elif option == 'aux_state': - copt = _ctypes.c_int(2) - else: - raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}") _check_call(_LIB.NNSymbolListInputNames( - self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr))) + self.handle, self._get_list_copt(option), + _ctypes.byref(size), _ctypes.byref(sarr))) return [_base.py_str(sarr[i]) for i in range(size.value)] - def list_outputs(self): + def list_output_names(self): """List all outputs in the symbol. Returns diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index 266e2e79606b..fae58636deae 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -221,6 +221,25 @@ int NNSymbolListAttrs(SymbolHandle symbol, API_END(); } +int NNSymbolListInputVariables(SymbolHandle symbol, + int option, + nn_uint *out_size, + SymbolHandle** out_sym_array) { + Symbol *s = static_cast(symbol); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + API_BEGIN(); + std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); + ret->ret_handles.clear(); + for (size_t i = 0; i < vs.size(); ++i) { + nnvm::Symbol* rs = new nnvm::Symbol(); + rs->outputs.push_back(NodeEntry{vs[i], 0, 0}); + ret->ret_handles.push_back(rs); + } + *out_size = static_cast(vs.size()); + *out_sym_array = dmlc::BeginPtr(ret->ret_handles); + API_END(); +} + int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint *out_size, diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index ba6633043518..efb4a7a451f3 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -87,7 +87,8 @@ inline std::vector GetKeys( // whether the symbol is atomic functor inline bool IsAtomic(const std::vector& 
outputs) { - return outputs[0].node->inputs.size() == 0; + return outputs[0].node->inputs.size() == 0 && + outputs[0].node->control_deps.size() == 0; } // public functions @@ -118,7 +119,9 @@ Symbol Symbol::Copy() const { } void Symbol::Print(std::ostream &os) const { - if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) { + if (outputs.size() == 1 && + outputs[0].node->inputs.size() == 0 && + outputs[0].node->control_deps.size() == 0) { if (outputs[0].node->is_variable()) { os << "Variable:" << outputs[0].node->attrs.name << '\n'; } else { diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 0f3f57fd7cf4..c6460cab3ad0 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -69,6 +69,7 @@ Graph Gradient(Graph src) { // topo sort std::vector topo_order; std::unordered_map > output_grads; + DFSVisit(ys, [&](const NodePtr& node) { if (output_grads.count(node.get()) == 0) { output_grads[node.get()].resize(node->num_outputs()); @@ -113,13 +114,15 @@ Graph Gradient(Graph src) { e.sum = agg_fun(std::move(e.grads)); out_agg_grads.push_back(e.sum); } - std::vector input_grads = grad_fun_map[ptr->op()] - (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); - CHECK_EQ((*rit)->inputs.size(), input_grads.size()) - << "Gradient function not returning enough gradient"; - auto git = input_grads.begin(); - for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { - output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git)); + if ((*rit)->inputs.size() != 0) { + std::vector input_grads = grad_fun_map[ptr->op()] + (mirror_map.size() == 0 ? 
ptr : mirror_map.at(ptr.get()), out_agg_grads); + CHECK_EQ((*rit)->inputs.size(), input_grads.size()) + << "Gradient function not returning enough gradient"; + auto git = input_grads.begin(); + for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { + output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git)); + } } } // take out the xs' grads diff --git a/nnvm/tests/python/test_graph.py b/nnvm/tests/python/test_graph.py index 9424a2ec66cc..bdf4a6aaca52 100644 --- a/nnvm/tests/python/test_graph.py +++ b/nnvm/tests/python/test_graph.py @@ -42,8 +42,8 @@ def test_list_args(): y = sym.add(y, z, name='add1') # write after read z = sym.assign(x, y, name='assign') - assert z.list_inputs('read_only') == ['conv_weight', 'z'] - assert z.list_inputs('aux_state') == ['x'] + assert z.list_input_names('read_only') == ['conv_weight', 'z'] + assert z.list_input_names('aux_state') == ['x'] def test_infer_shape(): x = sym.Variable('x', shape=(4, 2)) diff --git a/nnvm/tests/python/test_symbol.py b/nnvm/tests/python/test_symbol.py index 989cd99e730f..915ece3da69d 100644 --- a/nnvm/tests/python/test_symbol.py +++ b/nnvm/tests/python/test_symbol.py @@ -7,17 +7,19 @@ def test_compose(): y = sym.exp(sym.add(x, x, name='add', gpu=2), name='exp', gpu=1, attr={"kk": "1"}) - assert y.list_inputs() == ['x'] - assert y.list_outputs() == ["exp_output"] + assert y.list_input_names() == ['x'] + assert y.list_output_names() == ["exp_output"] assert y.list_attr()['gpu'] == '1' z = y.get_internals() - assert z['add_output'].list_outputs() == ['add_output'] + assert z['add_output'].list_output_names() == ['add_output'] assert y.list_attr(recursive=True)['add_gpu'] == '2' def test_default_input(): x = sym.Variable('x') y = sym.conv2d(data=x, name='conv') - assert y.list_inputs() == ['x', 'conv_weight'] + assert y.list_input_names() == ['x', 'conv_weight'] + tname = [z.list_output_names()[0] for z in y.list_input_variables()] + assert tname == 
y.list_input_names() try: z = sym.add(x) assert False