-
Notifications
You must be signed in to change notification settings - Fork 281
Conversation
what is the use-case? |
This is for implementing a control flow operator in MXNet. Please see apache/mxnet#10451 as an example. I think this functionality is useful in general. If we want to implement a high-order function like the ones in tensorflow, we also need to pass a subgraph to an operator. |
include/nnvm/node.h
Outdated
@@ -80,6 +82,7 @@ struct NodeAttrs { | |||
* For place holder variable, op == nullptr. | |||
*/ | |||
const Op *op{nullptr}; | |||
std::shared_ptr<Graph> g; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
put this to the end of the current NodeAttrs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comments, the field name g is a bit too generic. Need to comment on what is the semantics here, is it just an attribute so that we can use to to compose a high order AST, or something else
include/nnvm/op_attr_types.h
Outdated
@@ -176,6 +176,8 @@ using FSetInputVarAttrOnCompose = std::function<void( | |||
NodePtr var, | |||
const int index)>; | |||
|
|||
using FInputGraph = std::function<uint32_t(const NodeAttrs& attrs)>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
document what it is
src/core/symbolic.cc
Outdated
// compositional logic | ||
void Symbol::Compose(const array_view<const Symbol*>& args, | ||
const std::unordered_map<std::string, const Symbol*>& kwargs, | ||
const std::string& name) { | ||
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames"); | ||
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose"); | ||
static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph"); | ||
|
||
Node* n = outputs[0].node.get(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment the behavior on what is expected when a node contains input graph.
@tqchen Does the PR look OK? Should I add more tests? Do you have any suggestions on how to move forward? Thanks |
cc @piiswrong can you also do a review? |
include/nnvm/node.h
Outdated
* mini-batches. In this sense, the subgraphs are kind of similar to | ||
* the parameters and show be kept as node attributes. | ||
*/ | ||
std::vector<std::shared_ptr<Graph> > subgraphs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why shared ptr?
src/core/symbolic.cc
Outdated
for (size_t i = 0; i < gidx.num_nodes(); ++i) { | ||
for (const auto& j : gidx[i].inputs) ++ref_count[gidx.entry_id(j)]; | ||
} | ||
g->attrs["forward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dont do this here
@piiswrong any updates on this? |
@tqchen @piiswrong suggests that I should test it with everything in MXNet first before merging this PR. I'm almost finishing foreach. I'll run all the tests by the week and run some tests for this PR. After that, can we merge this PR? |
@piiswrong @tqchen I have run through all unit tests in mxnet (tests/python/unittest/*). They all work fine. Will this be good enough for merging this PR? |
need to rebase master, and need @piiswrong 's approval |
for some reason, CI fails. I tried on my local machine, the tests work fine. |
3cbda39
to
9636c9b
Compare
It seems the CI fails after I add the field of subgraphs in NodeAttrs. Does NNVM do anything special on NodeAttrs? @piiswrong @tqchen |
@tqchen after rebase, everything works fine now. |
c.f. #518 we will redirect further changes to tvm repo, please open a new PR there. Please invite the original set of reviewers when the new PR is opened so they can review and approve the changes |
Fix markdown syntax error (code shifts out of markdown-code box).
This PR allows an operator to accept a graph symbol as input. The input graph will be stored inside the corresponding node.