Skip to content

Commit

Permalink
Ghost nodes in NNVM graph (apache#3290)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx authored and wweic committed Jun 27, 2019
1 parent 092a675 commit 20a2d00
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
11 changes: 11 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ using FInferType = FInferNodeEntryAttr<int>;
*/
using TIsBackward = bool;

/*!
* \brief Whether this op is a ghost node.
* If TIsGhost is true:
* - The node with this op will not be visible in the indexed graph.
*
* \note Register under "TIsGhost"
* This enables shape/type inference for backward nodes when
* fusion is present.
*/
using TIsGhost = bool;

/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
Expand Down
3 changes: 3 additions & 0 deletions nnvm/src/core/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {

DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) {
const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
CHECK(n);
Expand Down Expand Up @@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
inputs_rptr.push_back(input_entries_.size());
// control deps
for (const auto& nptr : n->control_deps) {
if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
auto it = node2index_.find(nptr.get());
CHECK(it != node2index_.end() && it->first == nptr.get());
control_deps_.push_back(it->second);
Expand Down

0 comments on commit 20a2d00

Please sign in to comment.