Skip to content

Commit

Permalink
Bugfix plan memory, fully support mxnet executor (apache#32)
Browse files Browse the repository at this point in the history
* [PASS] include knullop info in plan memory

* Bugfix plan memory, fully support mxnet
  • Loading branch information
tqchen committed May 29, 2018
1 parent 7c3d18c commit 9a4e133
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
12 changes: 12 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ namespace nnvm {
*/
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;

/*!
* \brief Return number of visible outputs by the user.
*
* \param attrs The attributes of the node.
*
* \note Register under "FNumVisibleOutputs", default not registered.
* This can be used to hide certain output from the user,
* but the additional outputs can be used to pass information from
* forward to gradient pass.
*/
using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>;

/*!
* \brief Return list of output arguments names of each operator.
*
Expand Down
19 changes: 15 additions & 4 deletions nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ inline std::vector<std::string> GetKeys(

// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs.size() == 1 && outputs[0].node->inputs.size() == 0;
return outputs[0].node->inputs.size() == 0;
}

// public functions
Expand Down Expand Up @@ -222,6 +222,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {

std::vector<std::string> Symbol::ListOutputNames() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");

std::vector<std::string> ret;
for (auto &head : outputs) {
if (head.node->is_variable()) {
Expand Down Expand Up @@ -256,8 +257,6 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");

CHECK_EQ(outputs.size(), 1)
<< "Only composition of value function is supported currently";
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
Expand Down Expand Up @@ -400,6 +399,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
}

Symbol Symbol::GetInternals() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
DFSVisit(this->outputs, [&ret](const NodePtr& node) {
Node* n = node.get();
Expand All @@ -409,6 +409,9 @@ Symbol Symbol::GetInternals() const {
ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
} else {
uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i, 0});
}
Expand Down Expand Up @@ -467,14 +470,22 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op

Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s;
NodePtr n = Node::Create();
n->attrs.op = op;
n->attrs.dict = std::move(attrs);
if (n->op()->attr_parser != nullptr) {
n->op()->attr_parser(&(n->attrs));
}
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});

uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
}
return s;
}

Expand Down
1 change: 0 additions & 1 deletion nnvm/src/pass/place_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ namespace nnvm {
namespace pass {
namespace {


// simply logic to place device according to device_group hint
// insert copy node when there is
Graph PlaceDevice(Graph src) {
Expand Down
7 changes: 5 additions & 2 deletions nnvm/src/pass/plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ Graph PlanMemory(Graph ret) {
// step 1: initialize reference count
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
for (const auto& e : idx[nid].inputs) {
++ref_count[e.node_id];
++ref_count[idx.entry_id(e)];
}
}
for (const auto& e : idx.outputs()) {
++ref_count[e.node_id];
++ref_count[idx.entry_id(e)];
}
// step 2: allocate memory.
StorageVector storage(idx.num_node_entries(), -1);
Expand Down Expand Up @@ -202,10 +202,13 @@ Graph PlanMemory(Graph ret) {
}
}
// check if there are outputs that can be freeded immediately
// these output are not referenced by any operator.
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator.Release(storage[eid], nid);
// use -2 to indicate that the node was never touched.
storage_inplace_index[eid] = -2;
}
if (storage[eid] == GraphAllocator::kBadStorageID) {
++num_not_allocated;
Expand Down

0 comments on commit 9a4e133

Please sign in to comment.