forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add type checker skeleton from sslyu (#7)
Add type checker skeleton
- Loading branch information
Showing
3 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file environment.h | ||
* \brief Relay typechecker | ||
*/ | ||
#ifndef NNVM_RELAY_TYPECHECKER_H_ | ||
#define NNVM_RELAY_TYPECHECKER_H_ | ||
|
||
#include <unordered_map> | ||
#include "node.h" | ||
#include "expr_functor.h" | ||
#include "environment.h" | ||
|
||
namespace nnvm { | ||
namespace relay { | ||
|
||
class Typechecker : public ExprFunctor<Type(const Expr & n)> { | ||
public: | ||
Environment env; | ||
Typechecker(); | ||
Typechecker(Environment env) : env(env) {} | ||
Type Check(const Expr & expr); | ||
Type VisitExpr_(const LocalIdNode* op); | ||
Type VisitExpr_(const GlobalIdNode* op); | ||
Type VisitExpr_(const IntrinsicIdNode* op); | ||
Type VisitExpr_(const FloatLitNode* op); | ||
Type VisitExpr_(const BoolLitNode* op); | ||
Type VisitExpr_(const IntLitNode* op); | ||
Type VisitExpr_(const TensorLitNode* op); | ||
Type VisitExpr_(const ProductLitNode* op); | ||
Type VisitExpr_(const CastNode* op); | ||
Type VisitExpr_(const ParamNode* op); | ||
Type VisitExpr_(const FunctionNode* op); | ||
Type VisitExpr_(const CallNode* op); | ||
Type VisitExpr_(const DebugNode* op); | ||
Type VisitExpr_(const UnaryOpNode* op); | ||
Type VisitExpr_(const BinaryOpNode* op); | ||
Type VisitExpr_(const AssignmentNode* op); | ||
}; | ||
|
||
} // namespace relay | ||
} // namespace nnvm | ||
#endif // NNVM_RELAY_TYPECHECKER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""FFI constructors of all relay AST nodes.""" | ||
|
||
from tvm._ffi.function import _init_api | ||
|
||
_init_api("nnvm.tyck", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file typechecker.cc | ||
* \brief Relay typechecker | ||
*/ | ||
|
||
#include <nnvm/relay/typechecker.h> | ||
#include <vector> | ||
|
||
namespace nnvm { | ||
namespace relay { | ||
|
||
using namespace tvm::runtime; | ||
|
||
class TypecheckerError : public std::exception { | ||
public: | ||
std::string message; | ||
explicit TypecheckerError(std::string message) : message(message) {} | ||
|
||
const char *what() const noexcept { return this->message.c_str(); } | ||
}; | ||
|
||
Typechecker::Typechecker() : env() {} | ||
|
||
Type Typechecker::Check(const Expr &expr) { return this->VisitExpr(expr); } | ||
|
||
Type Typechecker::VisitExpr_(const LocalIdNode *op) { | ||
throw TypecheckerError("LocalIdNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const GlobalIdNode *op) { | ||
throw TypecheckerError("GlobalIdNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const IntrinsicIdNode *op) { | ||
throw TypecheckerError("IntrinsicIdNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const FloatLitNode *op) { | ||
throw TypecheckerError("FloatLitNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const BoolLitNode *op) { | ||
throw TypecheckerError("BoolLitNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const IntLitNode *op) { | ||
throw TypecheckerError("IntLitNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const TensorLitNode *op) { | ||
throw TypecheckerError("TensorLitNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const ProductLitNode *op) { | ||
throw TypecheckerError("ProductLitNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const CastNode *op) { | ||
throw TypecheckerError("CastNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const ParamNode *op) { | ||
throw TypecheckerError("ParamNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const FunctionNode *op) { | ||
throw TypecheckerError("FunctionNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const CallNode *op) { | ||
throw TypecheckerError("CallNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const DebugNode *op) { | ||
throw TypecheckerError("DebugNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const UnaryOpNode *op) { | ||
throw TypecheckerError("UnaryOpNode not implemented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const BinaryOpNode *op) { | ||
throw TypecheckerError("BinaryOpNode not implmented"); | ||
} | ||
|
||
Type Typechecker::VisitExpr_(const AssignmentNode *op) { | ||
throw TypecheckerError("AssignmentNode not implemented"); | ||
} | ||
|
||
TVM_REGISTER_API("nnvm.tyck.check") | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntTypeNode::make(10); | ||
}); | ||
|
||
} // namespace relay | ||
} // namespace nnvm |