From af85bb4b1578069d41385f1397fd593d1242230d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Apr 2018 23:11:29 -0700 Subject: [PATCH] Add type checker skeleton from sslyu (#7) Add type checker skeleton --- relay/include/relay/typechecker.h | 43 ++++++++++++++ relay/python/relay/tyck.py | 5 ++ relay/src/relay/typechecker.cc | 97 +++++++++++++++++++++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 relay/include/relay/typechecker.h create mode 100644 relay/python/relay/tyck.py create mode 100644 relay/src/relay/typechecker.cc diff --git a/relay/include/relay/typechecker.h b/relay/include/relay/typechecker.h new file mode 100644 index 0000000000000..e422e40b4d1eb --- /dev/null +++ b/relay/include/relay/typechecker.h @@ -0,0 +1,43 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file environment.h + * \brief Relay typechecker + */ +#ifndef NNVM_RELAY_TYPECHECKER_H_ +#define NNVM_RELAY_TYPECHECKER_H_ + +#include +#include "node.h" +#include "expr_functor.h" +#include "environment.h" + +namespace nnvm { +namespace relay { + +class Typechecker : public ExprFunctor { +public: + Environment env; + Typechecker(); + Typechecker(Environment env) : env(env) {} + Type Check(const Expr & expr); + Type VisitExpr_(const LocalIdNode* op); + Type VisitExpr_(const GlobalIdNode* op); + Type VisitExpr_(const IntrinsicIdNode* op); + Type VisitExpr_(const FloatLitNode* op); + Type VisitExpr_(const BoolLitNode* op); + Type VisitExpr_(const IntLitNode* op); + Type VisitExpr_(const TensorLitNode* op); + Type VisitExpr_(const ProductLitNode* op); + Type VisitExpr_(const CastNode* op); + Type VisitExpr_(const ParamNode* op); + Type VisitExpr_(const FunctionNode* op); + Type VisitExpr_(const CallNode* op); + Type VisitExpr_(const DebugNode* op); + Type VisitExpr_(const UnaryOpNode* op); + Type VisitExpr_(const BinaryOpNode* op); + Type VisitExpr_(const AssignmentNode* op); +}; + +} // namespace relay +} // namespace nnvm +#endif // NNVM_RELAY_TYPECHECKER_H_ diff --git a/relay/python/relay/tyck.py b/relay/python/relay/tyck.py new file mode 100644 index 0000000000000..cec2d8aa7802f --- /dev/null +++ b/relay/python/relay/tyck.py @@ -0,0 +1,5 @@ +"""FFI constructors of all relay AST nodes.""" + +from tvm._ffi.function import _init_api + +_init_api("nnvm.tyck", __name__) diff --git a/relay/src/relay/typechecker.cc b/relay/src/relay/typechecker.cc new file mode 100644 index 0000000000000..00833b4273b1e --- /dev/null +++ b/relay/src/relay/typechecker.cc @@ -0,0 +1,97 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file typechecker.cc + * \brief Relay typechecker + */ + +#include +#include + +namespace nnvm { +namespace relay { + +using namespace tvm::runtime; + +class TypecheckerError : public std::exception { + public: + std::string message; + explicit TypecheckerError(std::string message) : message(message) {} + + const char *what() const noexcept { return this->message.c_str(); } +}; + +Typechecker::Typechecker() : env() {} + +Type Typechecker::Check(const Expr &expr) { return this->VisitExpr(expr); } + +Type Typechecker::VisitExpr_(const LocalIdNode *op) { + throw TypecheckerError("LocalIdNode not implemented"); +} + +Type Typechecker::VisitExpr_(const GlobalIdNode *op) { + throw TypecheckerError("GlobalIdNode not implemented"); +} + +Type Typechecker::VisitExpr_(const IntrinsicIdNode *op) { + throw TypecheckerError("IntrinsicIdNode not implemented"); +} + +Type Typechecker::VisitExpr_(const FloatLitNode *op) { + throw TypecheckerError("FloatLitNode not implemented"); +} + +Type Typechecker::VisitExpr_(const BoolLitNode *op) { + throw TypecheckerError("BoolLitNode not implemented"); +} + +Type Typechecker::VisitExpr_(const IntLitNode *op) { + throw TypecheckerError("IntLitNode not implemented"); +} + +Type Typechecker::VisitExpr_(const TensorLitNode *op) { + throw TypecheckerError("TensorLitNode not implemented"); +} + +Type Typechecker::VisitExpr_(const ProductLitNode *op) { + throw TypecheckerError("ProductLitNode not implemented"); +} + +Type Typechecker::VisitExpr_(const CastNode *op) { + throw TypecheckerError("CastNode not implemented"); +} + +Type Typechecker::VisitExpr_(const ParamNode *op) { + throw TypecheckerError("ParamNode not implemented"); +} + +Type Typechecker::VisitExpr_(const FunctionNode *op) { + throw TypecheckerError("FunctionNode not implemented"); +} + +Type Typechecker::VisitExpr_(const CallNode *op) { + throw TypecheckerError("CallNode not implemented"); +} + +Type Typechecker::VisitExpr_(const DebugNode *op) { + throw TypecheckerError("DebugNode not implemented"); +} + +Type Typechecker::VisitExpr_(const UnaryOpNode *op) { + throw TypecheckerError("UnaryOpNode not implemented"); +} + +Type Typechecker::VisitExpr_(const BinaryOpNode *op) { + throw TypecheckerError("BinaryOpNode not implmented"); +} + +Type Typechecker::VisitExpr_(const AssignmentNode *op) { + throw TypecheckerError("AssignmentNode not implemented"); +} + +TVM_REGISTER_API("nnvm.tyck.check") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IntTypeNode::make(10); + }); + +} // namespace relay +} // namespace nnvm