Skip to content

Commit

Permalink
Add initial implementation of the pretty printer (apache#17)
Browse files Browse the repository at this point in the history
* cp changes

* hack around lookup

* ignore core

* find problem

* find the reason of nullptr

* Update node.h

* fix lint

* commit

* fix

* fix
  • Loading branch information
MarisaKirisame authored and jroesch committed Aug 16, 2018
1 parent 53f5ac5 commit 3fe4cfa
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 6 deletions.
9 changes: 7 additions & 2 deletions relay/include/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <tvm/ir_functor.h>
#include <unordered_map>
#include <string>
#include "node.h"

namespace nnvm {
Expand All @@ -16,8 +17,12 @@ namespace relay {
struct Environment;

/*! \brief Integer literal `0`, `1000`. */
class EnvironmentNode : public ValueNode {
public:
struct EnvironmentNode : ValueNode {
std::unordered_map<std::string, GlobalId> table;
// What if there are two globalid with the same name?
// This should be fixed in the python code,
// But I havent take much look into it, so I will just hack around.

tvm::Map<GlobalId, Item> items;

EnvironmentNode() {}
Expand Down
1 change: 0 additions & 1 deletion relay/include/relay/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,6 @@ class DefnNode : public ItemNode {
public:
Type type;
Expr body;

DefnNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
Expand Down
46 changes: 46 additions & 0 deletions relay/include/relay/pretty_printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*!
* Copyright (c) 2018 by Contributors
* \file pretty_printer.h
* \brief A pretty printer for the Relay IR
*/
#ifndef NNVM_RELAY_PRETTY_PRINTER_H_
#define NNVM_RELAY_PRETTY_PRINTER_H_

#include <unordered_map>
#include <string>
#include <vector>
#include <iostream>
#include "environment.h"
#include "expr_functor.h"
#include "node.h"

using std::ostream;

namespace nnvm {
namespace relay {

struct PrettyPrinter : public ExprFunctor<void(const Expr& n, ostream & os)> {
Environment env;
PrettyPrinter(Environment env) : env(env) {}
void PrettyPrint(const Expr & expr, ostream & os);
void VisitExpr_(const LocalIdNode * op, ostream & os) override;
void VisitExpr_(const GlobalIdNode * op, ostream & os) override;
void VisitExpr_(const IntrinsicIdNode * op, ostream & os) override;
void VisitExpr_(const FloatLitNode * op, ostream & os) override;
void VisitExpr_(const BoolLitNode * op, ostream & os) override;
void VisitExpr_(const IntLitNode * op, ostream & os) override;
void VisitExpr_(const TensorLitNode * op, ostream & os) override;
void VisitExpr_(const ProductLitNode * op, ostream & os) override;
void VisitExpr_(const CastNode * op, ostream & os) override;
void VisitExpr_(const ParamNode * op, ostream & os) override;
void VisitExpr_(const FunctionNode * op, ostream & os) override;
void VisitExpr_(const CallNode * op, ostream & os) override;
void VisitExpr_(const DebugNode * op, ostream & os) override;
void VisitExpr_(const UnaryOpNode * op, ostream & os) override;
void VisitExpr_(const BinaryOpNode * op, ostream & os) override;
void VisitExpr_(const AssignmentNode * op, ostream & os) override;
};

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_PRETTY_PRINTER_H_
5 changes: 5 additions & 0 deletions relay/python/relay/pretty_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""FFI constructors of all relay AST nodes."""

from tvm._ffi.function import _init_api

_init_api("nnvm.pretty_printer", __name__)
5 changes: 4 additions & 1 deletion relay/python/relay/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def visit_Return(self, return_node):
raise Exception("return must have a value")

def compile_stmt_seq_to_body(self, stmts):
return self.visit(stmts[0])
x = self.visit(stmts[0])
#assert x is not None
#todo(M.K.) somehow it is returning null. very bad.
return x

def run(self):
"""executes visitor"""
Expand Down
12 changes: 10 additions & 2 deletions relay/src/relay/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@ Environment EnvironmentNode::make(tvm::Map<GlobalId, Item> items) {
return Environment(n);
}

void EnvironmentNode::add(const Item &item) { this->items.Set(item->id, item); }
void EnvironmentNode::add(const Item &item) {
this->table.insert({item->id->name, item->id});
this->items.Set(item->id, item);
}

dmlc::optional<Item> EnvironmentNode::lookup(const GlobalIdNode *id) {
return dmlc::optional<Item>();
auto nit = this->table.find(id->name);
if (nit == this->table.end()) {
return dmlc::optional<Item>();
}
auto it = this->items.find((*nit).second);
return it == this->items.end() ? dmlc::optional<Item>() : dmlc::optional<Item>((*it).second);
}

TVM_REGISTER_API("nnvm.make.Environment")
Expand Down
136 changes: 136 additions & 0 deletions relay/src/relay/pretty_printer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*!
* Copyright (c) 2018 by Contributors
* \file pretty_printer.cc
* \brief A pretty printer for the Relay IR.
*/

#include <nnvm/relay/pretty_printer.h>
#include <vector>
#include <iostream>
#include <sstream>

namespace nnvm {
namespace relay {

using namespace tvm::runtime;

struct PrintError : std::exception {
std::string msg;
explicit PrintError(std::string msg) : msg(msg) {}

const char* what() const noexcept { return msg.c_str(); }
};

void PrettyPrinter::PrettyPrint(const Expr & expr, ostream & os) {
return this->operator()(expr, os);
}

void PrettyPrinter::VisitExpr_(const LocalIdNode * local, ostream & os) {
throw PrintError("printer for LocalIdNode NYI");
}

void PrettyPrinter::VisitExpr_(const GlobalIdNode * op, ostream & os) {
if (auto global = this->env->lookup(op)) {
Item i = *global;
if (const DefnNode* def = i.as<DefnNode>()) {
this->PrettyPrint(def->body, os);
} else {
throw PrintError("unknown global id");
}
} else {
throw PrintError("unknown global value");
}
}

void PrettyPrinter::VisitExpr_(const IntrinsicIdNode * op, ostream & os) {
throw PrintError("IntrinsicId NYI");
}

void PrettyPrinter::VisitExpr_(const FloatLitNode * op, ostream & os) {
os << op->value;
}

void PrettyPrinter::VisitExpr_(const BoolLitNode * op, ostream & os) {
os << op->value;
}

void PrettyPrinter::VisitExpr_(const IntLitNode * op, ostream & os) {
os << op->value;
}

void PrettyPrinter::VisitExpr_(const TensorLitNode * op, ostream & os) {
throw PrintError("printer for TENSOR NYI");
}

void PrettyPrinter::VisitExpr_(const ProductLitNode * op, ostream & os) {
throw PrintError("printer for PRODUCTLIT NYI");
}

void PrettyPrinter::VisitExpr_(const CastNode * op, ostream & os) {
this->PrettyPrint(op->node, os);
}

void PrettyPrinter::VisitExpr_(const ParamNode * op, ostream & os) {
throw PrintError("printer for param NYI");
}

void PrettyPrinter::VisitExpr_(const FunctionNode * op, ostream & os) {
os << "(\\";
if (op->body.get() == nullptr) {
// throw PrintError("empty function body");
// todo(M.K.) fix this
os << "undefined";
} else {
this->PrettyPrint(op->body, os);
}
os << ")";
}

void PrettyPrinter::VisitExpr_(const CallNode * op, ostream & os) {
throw PrintError("Call NYI");
}

void PrettyPrinter::VisitExpr_(const DebugNode * op, ostream & os) {
throw PrintError("Debug NYI");
}

void PrettyPrinter::VisitExpr_(const UnaryOpNode * op, ostream & os) {
switch (op->op) {
case UOp::NEG: {
os << "(";
os << "-";
this->PrettyPrint(op->node, os);
os << ")";
}
default:
throw PrintError("unknown UOP");
}
}

void PrettyPrinter::VisitExpr_(const BinaryOpNode * op, ostream & os) {
switch (op->op) {
case BinOp::PLUS: {
this->PrettyPrint(op->left, os);
os << "+";
this->PrettyPrint(op->right, os);
}
default:
throw PrintError("unknown BOP");
}
}

void PrettyPrinter::VisitExpr_(const AssignmentNode * op, ostream & os) {
throw PrintError("Assignment NYI");
}

TVM_REGISTER_API("nnvm.pretty_printer.pretty_print").set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0];
Expr expr = args[1];
PrettyPrinter pp(env);
std::ostringstream os;
pp.PrettyPrint(expr, os);
*ret = os.str();
});

} // namespace relay
} // namespace nnvm

0 comments on commit 3fe4cfa

Please sign in to comment.