From 4b64f6ec4704d7b73bb5152936db0895cd4e25ac Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Tue, 30 Aug 2016 16:41:31 -0700
Subject: [PATCH] Enable copy on write in graph attrs (#31)

* [INFER] Enhance backward op policy

* [SYMBOL] add list inputs

* relax graph attr to enable copy-on-write
---
 nnvm/include/nnvm/graph.h         | 38 +++++++++++++++++++++++++++----
 nnvm/include/nnvm/op_attr_types.h | 13 ++++++-----
 nnvm/include/nnvm/symbolic.h      |  9 ++++++++
 nnvm/include/nnvm/tuple.h         |  2 +-
 nnvm/src/core/symbolic.cc         | 27 ++++++++++++++--------
 nnvm/src/pass/infer_shape_type.cc | 17 ++++++++------
 nnvm/src/pass/place_device.cc     | 12 ++++++++--
 nnvm/src/pass/saveload_json.cc    |  8 +++----
 8 files changed, 93 insertions(+), 33 deletions(-)

diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h
index 47e95f3b39250..5147b22e4aeca 100644
--- a/nnvm/include/nnvm/graph.h
+++ b/nnvm/include/nnvm/graph.h
@@ -30,18 +30,34 @@ class Graph {
   std::vector<NodeEntry> outputs;
   /*!
    * \brief attributes of a graph
-   *  Each attribute is immutable,
-   *  and can be shared across multiple Instance of graph
+   *  Note that each attribute is a shared pointer and can be shared across graphs.
+   *
+   *  It is highly recommended to keep each attribute immutable.
+   *  It is also safe to implement copy-on-write semantics.
+   *
+   *  Copy when shared_ptr.unique() is not true, and reuse the original space
+   *  when shared_ptr.unique() is true.
    */
-  std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
+  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
   /*!
-   * \brief Get the attribute from attrs.
+   * \brief Get the immutable attribute from attrs.
    * \param attr_name the name of the attribute
    * \return the reference to corresponding attribute
    * \tparam T the type of the attribute.
    */
   template<typename T>
   inline const T& GetAttr(const std::string& attr_name);
+  /*!
+   * \brief Get a move copy of the attribute, implementing copy-on-write semantics.
+   *  The content is moved if the reference counter of the shared_ptr is 1;
+   *  the attribute is erased from attrs after the call.
+   *
+   * \param attr_name the name of the attribute
+   * \return a new copy of the corresponding attribute.
+   * \tparam T the type of the attribute.
+   */
+  template<typename T>
+  inline T MoveCopyAttr(const std::string& attr_name);
   /*!
    * \brief get a indexed graph of current graph, if not exist, create it on demand
    * \return The indexed graph.
@@ -200,6 +216,20 @@ inline const T& Graph::GetAttr(const std::string& attr_name) {
   return nnvm::get<T>(*it->second);
 }
 
+template<typename T>
+inline T Graph::MoveCopyAttr(const std::string& attr_name) {
+  auto it = attrs.find(attr_name);
+  CHECK(it != attrs.end())
+      << "Cannot find attribute " << attr_name << " in the graph";
+  std::shared_ptr<any> sptr = it->second;
+  attrs.erase(it);
+  if (sptr.unique()) {
+    return std::move(nnvm::get<T>(*sptr));
+  } else {
+    return nnvm::get<T>(*sptr);
+  }
+}
+
 template <typename GNode, typename HashType,
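Note: together, the relaxed shared_ptr<any> map and MoveCopyAttr give passes a copy-on-write protocol: take the attribute out, mutate the privately owned value, then publish it back. A minimal sketch of the intended calling pattern, assuming a hypothetical pass and a "counter" attribute of type std::vector<int> (neither is part of this patch):

    #include <memory>
    #include <vector>
    #include <nnvm/graph.h>

    nnvm::Graph IncrementCounter(nnvm::Graph g) {
      // MoveCopyAttr erases "counter" from g.attrs. If g held the only
      // reference, the vector is moved out with no copy; if another graph
      // still shares it, a fresh copy is returned and the sharer is unaffected.
      std::vector<int> counter = g.MoveCopyAttr<std::vector<int> >("counter");
      for (int& c : counter) ++c;  // mutate the privately owned value
      // Publish the result back as a shared (ideally immutable) attribute.
      g.attrs["counter"] = std::make_shared<nnvm::any>(std::move(counter));
      return g;
    }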
diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h
index d3129f978cc2e..4d10c304a5d60 100644
--- a/nnvm/include/nnvm/op_attr_types.h
+++ b/nnvm/include/nnvm/op_attr_types.h
@@ -82,17 +82,18 @@ using FInferShape = FInferNodeEntryAttr<TShape>;
 using FInferType = FInferNodeEntryAttr<int>;
 
 /*!
- * \brief Whether this op is an explicit backward operator
+ * \brief Whether this op is an explicit backward operator,
+ *  and the correspondence of each output to an input of the forward op.
  *
- * If TIsBackwardOp is set to be true:
+ * If FBackwardOutToInIndex exists:
  *  - The first control_deps of the node points to the corresponding forward operator.
- *  - The outputs operator corresponds to exactly inputs of forward op one by one.
+ *  - The k-th output corresponds to the FBackwardOutToInIndex()[k]-th input of the forward op.
  *
- * \note Register under "TIsBackwardOp", default to false.
+ * \note Register under "FBackwardOutToInIndex".
  * This enables easier shape/type inference for backward operators for slice and reduction.
  */
-using TIsBackwardOp = bool;
+using FBackwardOutToInIndex = std::function<
+    std::vector<uint32_t> (const NodeAttrs& attrs)>;
 
 /*!
  * \brief Get possible inplace options.
diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h
index 8bca4cb3103c0..e153945e65a7d 100644
--- a/nnvm/include/nnvm/symbolic.h
+++ b/nnvm/include/nnvm/symbolic.h
@@ -62,6 +62,15 @@ class Symbol {
    * \return the symbol corresponds to the indexed element.
    */
   Symbol operator[] (size_t index) const;
+  /*!
+   * \brief List the input variable nodes.
+   * \param option The options to list the arguments.
+   *
+   *  The positions in the returned list also correspond to the calling positions in operator().
+   * \return the arguments list of this symbol; they can be either named or unnamed (empty string).
+   * \sa ListInputOption
+   */
+  std::vector<NodePtr> ListInputs(ListInputOption option) const;
   /*!
    * \brief List the input names.
    * \param option The options to list the arguments.
diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h
index fefb7ce5739de..326a5d9c6da60 100644
--- a/nnvm/include/nnvm/tuple.h
+++ b/nnvm/include/nnvm/tuple.h
@@ -233,7 +233,7 @@ class Tuple {
         return is;
       }
     }
-    index_t idx;
+    ValueType idx;
     std::vector<ValueType> tmp;
     while (is >> idx) {
       tmp.push_back(idx);
diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc
index 97b35648ee7dd..8da33faa6f541 100644
--- a/nnvm/src/core/symbolic.cc
+++ b/nnvm/src/core/symbolic.cc
@@ -180,37 +180,46 @@ Symbol Symbol::operator[] (size_t index) const {
   }
 }
 
-std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
-  std::vector<std::string> ret;
+std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
+  std::vector<NodePtr> ret;
   if (option == kAll) {
     DFSVisit(this->outputs, [&ret](const NodePtr &node) {
         if (node->is_variable()) {
-          ret.push_back(node->attrs.name);
+          ret.push_back(node);
         }
       });
   } else {
     std::unordered_set<Node*> mutable_set;
-    std::vector<Node*> vlist;
+    std::vector<NodePtr> vlist;
     static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
     DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
         if (node->is_variable()) {
-          vlist.push_back(node.get());
+          vlist.push_back(node);
        } else if (fmutate_inputs.count(node->op())) {
          for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
            mutable_set.insert(node->inputs[i].node.get());
          }
        }
      });
-    for (Node* node : vlist) {
-      if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
-          (option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
-        ret.push_back(node->attrs.name);
+    for (const NodePtr& node : vlist) {
+      if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
+          (option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
+        ret.emplace_back(node);
       }
     }
   }
   return ret;
 }
 
+std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
+  std::vector<NodePtr> inputs = ListInputs(option);
+  std::vector<std::string> ret(inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    ret[i] = inputs[i]->attrs.name;
+  }
+  return ret;
+}
+
 std::vector<std::string> Symbol::ListOutputNames() const {
   static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
   std::vector<std::string> ret;
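Note: ListInputNames is now a thin wrapper over the new ListInputs, which hands back the variable nodes themselves, so callers can inspect or annotate argument nodes directly instead of round-tripping through names. A small usage sketch (PrintAuxStates is an illustrative helper, not part of the patch):

    #include <iostream>
    #include <nnvm/symbolic.h>

    // Print the auxiliary states: inputs that some operator mutates in place
    // (as reported through the FMutateInputs attribute).
    void PrintAuxStates(const nnvm::Symbol& sym) {
      for (const nnvm::NodePtr& n :
           sym.ListInputs(nnvm::Symbol::kAuxiliaryStates)) {
        std::cout << n->attrs.name << '\n';  // same strings ListInputNames yields
      }
    }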
diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc
index a5cc8c13751f1..bb50e98b5ede9 100644
--- a/nnvm/src/pass/infer_shape_type.cc
+++ b/nnvm/src/pass/infer_shape_type.cc
@@ -24,8 +24,8 @@ Graph InferAttr(Graph &&ret,
   const IndexedGraph& idx = ret.indexed_graph();
   static auto& finfer_shape =
       Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
-  static auto& is_backward =
-      Op::GetAttr<TIsBackwardOp>("TIsBackwardOp");
+  static auto& backward_map =
+      Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
   // reshape shape vector
   AttrVector rshape(idx.num_node_entries(), def_value);
 
@@ -82,16 +82,19 @@ Graph InferAttr(Graph &&ret,
       for (uint32_t i = 0; i < num_outputs; ++i) {
         rshape[idx.entry_id(nid, i)] = oshape[i];
       }
-    } else if (is_backward.get(inode.source->op(), false)) {
+    } else if (backward_map.count(inode.source->op())) {
       // backward operator inference.
       CHECK_GE(inode.control_deps.size(), 1)
           << "BackwardOp need to have control_deps to its forward op";
       const auto& fnode = idx[inode.control_deps[0]];
-      CHECK_EQ(fnode.inputs.size(), num_outputs)
-          << "BackwardOp need to correspond to the forward node";
+      std::vector<uint32_t> out_map =
+          backward_map[inode.source->op()](inode.source->attrs);
       bool known = true;
-      for (size_t i = 0; i < fnode.inputs.size(); ++i) {
-        rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
+      for (size_t i = 0; i < out_map.size(); ++i) {
+        uint32_t in_id = out_map[i];
+        CHECK_LT(in_id, fnode.inputs.size());
+        rshape[idx.entry_id(nid, i)] =
+            rshape[idx.entry_id(fnode.inputs[in_id])];
         if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
       }
       num_unknown += !known;
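Note: under the old TIsBackwardOp flag, every backward output had to line up one-to-one with the forward inputs; the new attribute lets a backward op declare exactly which forward input each of its outputs mirrors, which is what makes shape/type inference work for slice- and reduction-style gradients. A registration sketch for a hypothetical backward op (the op name and arity are illustrative, not something this patch registers):

    #include <cstdint>
    #include <vector>
    #include <nnvm/op.h>
    #include <nnvm/op_attr_types.h>

    NNVM_REGISTER_OP(_sum_backward)
    .describe("Illustrative gradient op of a sum reduction")
    .set_num_inputs(1)   // consumes the output gradient of the forward sum
    .set_num_outputs(1)  // produces the gradient of the forward data input
    .set_attr<nnvm::FBackwardOutToInIndex>(
        "FBackwardOutToInIndex", [](const nnvm::NodeAttrs& attrs) {
          // Output 0 takes the shape/type of forward input 0, so InferAttr
          // can copy it from the forward node even though the forward output
          // (a scalar) says nothing about the input shape.
          return std::vector<uint32_t>{0};
        });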
diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc
index 402f2cff784c1..607c51a7f319d 100644
--- a/nnvm/src/pass/place_device.cc
+++ b/nnvm/src/pass/place_device.cc
@@ -12,6 +12,7 @@
 namespace nnvm {
 namespace pass {
 namespace {
+
 // simply logic to place device according to device_group hint
 // insert copy node when there is
 Graph PlaceDevice(Graph src) {
@@ -21,13 +22,20 @@ Graph PlaceDevice(Graph src) {
       << "Need graph attribute \"device_assign_map\" in PlaceDevice";
   CHECK_NE(src.attrs.count("device_copy_op"), 0)
       << "Need graph attribute \"device_copy_op\" in PlaceDevice";
-  std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
   const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
   auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
   const IndexedGraph& idx = src.indexed_graph();
-  DeviceVector device(idx.num_nodes(), -1);
+  DeviceVector device;
+  // copy-on-write semantics
+  if (src.attrs.count("device") != 0) {
+    device = src.MoveCopyAttr<DeviceVector>("device");
+    CHECK_EQ(device.size(), idx.num_nodes());
+  } else {
+    device.resize(idx.num_nodes(), -1);
+  }
+  // forward pass
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc
index 681daed7a1ddc..3b3f8561a33fe 100644
--- a/nnvm/src/pass/saveload_json.cc
+++ b/nnvm/src/pass/saveload_json.cc
@@ -12,11 +12,11 @@ namespace dmlc {
 namespace json {
 
 // overload handler for shared ptr
 template<>
-struct Handler<std::shared_ptr<const any> > {
-  inline static void Write(JSONWriter *writer, const std::shared_ptr<const any> &data) {
+struct Handler<std::shared_ptr<any> > {
+  inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
     writer->Write(*data);
   }
-  inline static void Read(JSONReader *reader, std::shared_ptr<const any> *data) {
+  inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
     any v;
     reader->Read(&v);
     *data = std::make_shared<any>(std::move(v));
   }
@@ -131,7 +131,7 @@ struct JSONGraph {
   std::vector<uint32_t> arg_nodes;
   std::vector<uint32_t> node_row_ptr;
   std::vector<JSONNode::Entry> heads;
-  std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
+  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
 
   void Save(dmlc::JSONWriter *writer) const {
     writer->BeginObject();
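Note: the whole patch hinges on one ownership rule, visible in MoveCopyAttr above and in PlaceDevice's reuse of the "device" attribute: move when the reference count is 1, copy otherwise. A standalone sketch of that rule using plain std::shared_ptr (no nnvm types; MoveCopy mirrors MoveCopyAttr after the map entry has been erased):

    #include <cassert>
    #include <memory>
    #include <utility>
    #include <vector>

    // Mirror of Graph::MoveCopyAttr's core decision, on a bare shared_ptr.
    std::vector<int> MoveCopy(std::shared_ptr<std::vector<int> > sptr) {
      if (sptr.unique()) {
        return std::move(*sptr);  // sole owner: steal the buffer, no copy
      } else {
        return *sptr;             // shared: deep copy; other owners unaffected
      }
    }

    int main() {
      auto a = std::make_shared<std::vector<int> >(1000, 7);
      auto b = a;                                    // two owners
      std::vector<int> v1 = MoveCopy(std::move(a));  // b still owns -> copies
      assert(b->size() == 1000);                     // b is untouched
      std::vector<int> v2 = MoveCopy(std::move(b));  // unique -> moved out
      return v1.size() == v2.size() ? 0 : 1;
    }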