forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial implementation of the pretty printer (apache#17)
* cp changes * hack around lookup * ignore core * find problem * find the reason of nullptr * Update node.h * fix lint * commit * fix * fix
- Loading branch information
1 parent
53f5ac5
commit 3fe4cfa
Showing
7 changed files
with
208 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file pretty_printer.h | ||
* \brief A pretty printer for the Relay IR | ||
*/ | ||
#ifndef NNVM_RELAY_PRETTY_PRINTER_H_ | ||
#define NNVM_RELAY_PRETTY_PRINTER_H_ | ||
|
||
#include <unordered_map> | ||
#include <string> | ||
#include <vector> | ||
#include <iostream> | ||
#include "environment.h" | ||
#include "expr_functor.h" | ||
#include "node.h" | ||
|
||
using std::ostream; | ||
|
||
namespace nnvm { | ||
namespace relay { | ||
|
||
struct PrettyPrinter : public ExprFunctor<void(const Expr& n, ostream & os)> { | ||
Environment env; | ||
PrettyPrinter(Environment env) : env(env) {} | ||
void PrettyPrint(const Expr & expr, ostream & os); | ||
void VisitExpr_(const LocalIdNode * op, ostream & os) override; | ||
void VisitExpr_(const GlobalIdNode * op, ostream & os) override; | ||
void VisitExpr_(const IntrinsicIdNode * op, ostream & os) override; | ||
void VisitExpr_(const FloatLitNode * op, ostream & os) override; | ||
void VisitExpr_(const BoolLitNode * op, ostream & os) override; | ||
void VisitExpr_(const IntLitNode * op, ostream & os) override; | ||
void VisitExpr_(const TensorLitNode * op, ostream & os) override; | ||
void VisitExpr_(const ProductLitNode * op, ostream & os) override; | ||
void VisitExpr_(const CastNode * op, ostream & os) override; | ||
void VisitExpr_(const ParamNode * op, ostream & os) override; | ||
void VisitExpr_(const FunctionNode * op, ostream & os) override; | ||
void VisitExpr_(const CallNode * op, ostream & os) override; | ||
void VisitExpr_(const DebugNode * op, ostream & os) override; | ||
void VisitExpr_(const UnaryOpNode * op, ostream & os) override; | ||
void VisitExpr_(const BinaryOpNode * op, ostream & os) override; | ||
void VisitExpr_(const AssignmentNode * op, ostream & os) override; | ||
}; | ||
|
||
} // namespace relay | ||
} // namespace nnvm | ||
#endif // NNVM_RELAY_PRETTY_PRINTER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""FFI constructors of all relay AST nodes.""" | ||
|
||
from tvm._ffi.function import _init_api | ||
|
||
_init_api("nnvm.pretty_printer", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file pretty_printer.cc | ||
* \brief A pretty printer for the Relay IR. | ||
*/ | ||
|
||
#include <nnvm/relay/pretty_printer.h> | ||
#include <vector> | ||
#include <iostream> | ||
#include <sstream> | ||
|
||
namespace nnvm { | ||
namespace relay { | ||
|
||
using namespace tvm::runtime; | ||
|
||
struct PrintError : std::exception { | ||
std::string msg; | ||
explicit PrintError(std::string msg) : msg(msg) {} | ||
|
||
const char* what() const noexcept { return msg.c_str(); } | ||
}; | ||
|
||
void PrettyPrinter::PrettyPrint(const Expr & expr, ostream & os) { | ||
return this->operator()(expr, os); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const LocalIdNode * local, ostream & os) { | ||
throw PrintError("printer for LocalIdNode NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const GlobalIdNode * op, ostream & os) { | ||
if (auto global = this->env->lookup(op)) { | ||
Item i = *global; | ||
if (const DefnNode* def = i.as<DefnNode>()) { | ||
this->PrettyPrint(def->body, os); | ||
} else { | ||
throw PrintError("unknown global id"); | ||
} | ||
} else { | ||
throw PrintError("unknown global value"); | ||
} | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const IntrinsicIdNode * op, ostream & os) { | ||
throw PrintError("IntrinsicId NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const FloatLitNode * op, ostream & os) { | ||
os << op->value; | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const BoolLitNode * op, ostream & os) { | ||
os << op->value; | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const IntLitNode * op, ostream & os) { | ||
os << op->value; | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const TensorLitNode * op, ostream & os) { | ||
throw PrintError("printer for TENSOR NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const ProductLitNode * op, ostream & os) { | ||
throw PrintError("printer for PRODUCTLIT NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const CastNode * op, ostream & os) { | ||
this->PrettyPrint(op->node, os); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const ParamNode * op, ostream & os) { | ||
throw PrintError("printer for param NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const FunctionNode * op, ostream & os) { | ||
os << "(\\"; | ||
if (op->body.get() == nullptr) { | ||
// throw PrintError("empty function body"); | ||
// todo(M.K.) fix this | ||
os << "undefined"; | ||
} else { | ||
this->PrettyPrint(op->body, os); | ||
} | ||
os << ")"; | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const CallNode * op, ostream & os) { | ||
throw PrintError("Call NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const DebugNode * op, ostream & os) { | ||
throw PrintError("Debug NYI"); | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const UnaryOpNode * op, ostream & os) { | ||
switch (op->op) { | ||
case UOp::NEG: { | ||
os << "("; | ||
os << "-"; | ||
this->PrettyPrint(op->node, os); | ||
os << ")"; | ||
} | ||
default: | ||
throw PrintError("unknown UOP"); | ||
} | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const BinaryOpNode * op, ostream & os) { | ||
switch (op->op) { | ||
case BinOp::PLUS: { | ||
this->PrettyPrint(op->left, os); | ||
os << "+"; | ||
this->PrettyPrint(op->right, os); | ||
} | ||
default: | ||
throw PrintError("unknown BOP"); | ||
} | ||
} | ||
|
||
void PrettyPrinter::VisitExpr_(const AssignmentNode * op, ostream & os) { | ||
throw PrintError("Assignment NYI"); | ||
} | ||
|
||
TVM_REGISTER_API("nnvm.pretty_printer.pretty_print").set_body([](TVMArgs args, TVMRetValue* ret) { | ||
Environment env = args[0]; | ||
Expr expr = args[1]; | ||
PrettyPrinter pp(env); | ||
std::ostringstream os; | ||
pp.PrettyPrint(expr, os); | ||
*ret = os.str(); | ||
}); | ||
|
||
} // namespace relay | ||
} // namespace nnvm |