Skip to content

Commit

Permalink
[PASS] PrintGraphIR Join attributes when print ir (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 9c529cb commit ddd23a8
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 5 deletions.
19 changes: 17 additions & 2 deletions nnvm/python/nnvm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,23 @@ def index(self):
self._index = GraphIndex(self)
return self._index

def graphir(self):
"""Get text form of graph ir."""
def ir(self, join_entry_attrs=None, join_node_attrs=None):
"""Get text form of graph ir.
Parameters
----------
join_entry_attrs : list of str
List of graph NodeEntry attribute to be
printed along each operator.
join_node_attrs : list of str
List of graph node attribute to be
printed along each operator.
"""
if join_entry_attrs:
self._set_json_attr("join_entry_attrs", join_entry_attrs, "list_str")
if join_node_attrs:
self._set_json_attr("join_node_attrs", join_node_attrs, "list_str")
return self.apply("PrintGraphIR").json_attr("graphir")

def apply(self, passes):
Expand Down
2 changes: 2 additions & 0 deletions nnvm/src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ Graph InferAttr(Graph &&ret,
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
// erase the provided arguments
ret.attrs.erase(attr_key_name);
} else {
shape_attr_key = attr_name;
}
// Temp space for shape inference.
std::vector<AttrType> ishape, oshape;
Expand Down
83 changes: 81 additions & 2 deletions nnvm/src/pass/print_graph_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,80 @@
*/
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>
#include <iostream>

namespace nnvm {
namespace pass {

using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>; // NOLINT(*)

template<typename T>
AttrPrinter GetVectorPrinter_(const T& vec) {
return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*)
os << vec[index];
};
}

AttrPrinter GetVectorPrinter(const Graph& graph,
const std::string& key) {
auto it = graph.attrs.find(key);
CHECK(it != graph.attrs.end())
<< "Cannot find " << key << " in graph attr";
const any& value = *(it->second);
if (value.type() == typeid(std::vector<TShape>)) {
return GetVectorPrinter_(
nnvm::get<std::vector<TShape> >(value));
} else if (value.type() == typeid(std::vector<int>)) {
return GetVectorPrinter_(
nnvm::get<std::vector<int> >(value));
} else if (value.type() == typeid(std::vector<std::string>)) {
return GetVectorPrinter_(
nnvm::get<std::vector<std::string> >(value));
} else {
LOG(FATAL) << "Cannot handle type " << value.type().name();
return nullptr;
}
}


// print the graph ir in readable format
void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
void PrintGraphIR_(Graph src,
const std::vector<std::string>& join_entry_attrs,
const std::vector<std::string>& join_node_attrs,
std::ostream& os) { // NOLINT(*)
const IndexedGraph& idx = src.indexed_graph();
std::vector<std::function<void(uint32_t, std::ostream&)> > trigger; // NOLINT(*)

for (const std::string& key : join_entry_attrs) {
AttrPrinter fp = GetVectorPrinter(src, key);
auto fprint = [&idx, key, fp](
uint32_t nid, std::ostream& os) { // NOLINT(*)
const IndexedGraph::Node& inode = idx[nid];
os << ", " << key << "=";
if (inode.source->num_outputs() != 1) {
os << '[';
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
if (i != 0) os << ", ";
fp(idx.entry_id(nid, i), os);
}
os << ']';
} else {
fp(idx.entry_id(nid, 0), os);
}
};
trigger.push_back(fprint);
}
for (const std::string& key : join_node_attrs) {
AttrPrinter fp = GetVectorPrinter(src, key);
auto fprint = [&idx, key, fp](
uint32_t nid, std::ostream& os) { // NOLINT(*)
os << key << "=";
fp(idx.entry_id(nid, 0), os);
};
trigger.push_back(fprint);
}

os << "Graph(";
if (idx.input_nodes().size() < 4) {
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
Expand Down Expand Up @@ -79,6 +145,10 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
}
os << "]";
}
// additional attribute trigger
for (const auto& fp : trigger) {
fp(nid, os);
}
os << "\n";
}
os << " ret ";
Expand Down Expand Up @@ -112,7 +182,16 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
// save a graph to json
Graph PrintGraphIR(Graph src) {
std::ostringstream os;
PrintGraphIR_(src, os);
std::vector<std::string> join_entry_attrs, join_node_attrs;
if (src.attrs.count("join_entry_attrs") != 0) {
join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >(
"join_entry_attrs");
}
if (src.attrs.count("join_node_attrs") != 0) {
join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >(
"join_node_attrs");
}
PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os);
Graph ret;
ret.attrs["graphir"] = std::make_shared<any>(os.str());
return ret;
Expand Down
2 changes: 1 addition & 1 deletion nnvm/tests/python/compiler/test_simplify_batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def check(dim, axis, nstep):
graph_attr.set_shape_inputs(g, ishape)
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
# Some prints for debug
# print(g1.graphir())
# print(g1.ir())
# assert graph equals as expected
graph_pass.check_graph_equal(g1, g2)

Expand Down
11 changes: 11 additions & 0 deletions nnvm/tests/python/unittest/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,19 @@ def test_plan_memory():
assert (storage_id[jnode_row_ptr[nindex["add2"]]] ==
storage_id[jnode_row_ptr[nindex["reshapek"]]])

def test_print_graph_ir():
x = sym.Variable("x", shape=(1, 1, 10, 20))
y = sym.conv2d(x + 1, name="y", channels=10, kernel_size=(3,3))
g = graph.create(y)
g = g.apply("InferShape")
ir1 = g.ir()
ir2 = g.ir(join_entry_attrs=["shape"])
assert("y_bias" in ir1)
assert("shape=" in ir2)


if __name__ == "__main__":
test_print_graph_ir()
test_json_pass_with_attr()
test_graph_json_attr()
test_json_pass()
Expand Down

0 comments on commit ddd23a8

Please sign in to comment.