Skip to content

Commit

Permalink
Add type checker skeleton from sslyu (#7)
Browse files Browse the repository at this point in the history
Add type checker skeleton
  • Loading branch information
jroesch committed Aug 16, 2018
1 parent 16195c0 commit af85bb4
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
43 changes: 43 additions & 0 deletions relay/include/relay/typechecker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright (c) 2017 by Contributors
* \file environment.h
* \brief Relay typechecker
*/
#ifndef NNVM_RELAY_TYPECHECKER_H_
#define NNVM_RELAY_TYPECHECKER_H_

#include <unordered_map>
#include "node.h"
#include "expr_functor.h"
#include "environment.h"

namespace nnvm {
namespace relay {

class Typechecker : public ExprFunctor<Type(const Expr & n)> {
public:
Environment env;
Typechecker();
Typechecker(Environment env) : env(env) {}
Type Check(const Expr & expr);
Type VisitExpr_(const LocalIdNode* op);
Type VisitExpr_(const GlobalIdNode* op);
Type VisitExpr_(const IntrinsicIdNode* op);
Type VisitExpr_(const FloatLitNode* op);
Type VisitExpr_(const BoolLitNode* op);
Type VisitExpr_(const IntLitNode* op);
Type VisitExpr_(const TensorLitNode* op);
Type VisitExpr_(const ProductLitNode* op);
Type VisitExpr_(const CastNode* op);
Type VisitExpr_(const ParamNode* op);
Type VisitExpr_(const FunctionNode* op);
Type VisitExpr_(const CallNode* op);
Type VisitExpr_(const DebugNode* op);
Type VisitExpr_(const UnaryOpNode* op);
Type VisitExpr_(const BinaryOpNode* op);
Type VisitExpr_(const AssignmentNode* op);
};

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_TYPECHECKER_H_
5 changes: 5 additions & 0 deletions relay/python/relay/tyck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""FFI constructors of all relay AST nodes."""

from tvm._ffi.function import _init_api

_init_api("nnvm.tyck", __name__)
97 changes: 97 additions & 0 deletions relay/src/relay/typechecker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*!
* Copyright (c) 2018 by Contributors
* \file typechecker.cc
* \brief Relay typechecker
*/

#include <nnvm/relay/typechecker.h>
#include <vector>

namespace nnvm {
namespace relay {

using namespace tvm::runtime;

class TypecheckerError : public std::exception {
public:
std::string message;
explicit TypecheckerError(std::string message) : message(message) {}

const char *what() const noexcept { return this->message.c_str(); }
};

Typechecker::Typechecker() : env() {}

Type Typechecker::Check(const Expr &expr) { return this->VisitExpr(expr); }

Type Typechecker::VisitExpr_(const LocalIdNode *op) {
throw TypecheckerError("LocalIdNode not implemented");
}

Type Typechecker::VisitExpr_(const GlobalIdNode *op) {
throw TypecheckerError("GlobalIdNode not implemented");
}

Type Typechecker::VisitExpr_(const IntrinsicIdNode *op) {
throw TypecheckerError("IntrinsicIdNode not implemented");
}

Type Typechecker::VisitExpr_(const FloatLitNode *op) {
throw TypecheckerError("FloatLitNode not implemented");
}

Type Typechecker::VisitExpr_(const BoolLitNode *op) {
throw TypecheckerError("BoolLitNode not implemented");
}

Type Typechecker::VisitExpr_(const IntLitNode *op) {
throw TypecheckerError("IntLitNode not implemented");
}

Type Typechecker::VisitExpr_(const TensorLitNode *op) {
throw TypecheckerError("TensorLitNode not implemented");
}

Type Typechecker::VisitExpr_(const ProductLitNode *op) {
throw TypecheckerError("ProductLitNode not implemented");
}

Type Typechecker::VisitExpr_(const CastNode *op) {
throw TypecheckerError("CastNode not implemented");
}

Type Typechecker::VisitExpr_(const ParamNode *op) {
throw TypecheckerError("ParamNode not implemented");
}

Type Typechecker::VisitExpr_(const FunctionNode *op) {
throw TypecheckerError("FunctionNode not implemented");
}

Type Typechecker::VisitExpr_(const CallNode *op) {
throw TypecheckerError("CallNode not implemented");
}

Type Typechecker::VisitExpr_(const DebugNode *op) {
throw TypecheckerError("DebugNode not implemented");
}

Type Typechecker::VisitExpr_(const UnaryOpNode *op) {
throw TypecheckerError("UnaryOpNode not implemented");
}

Type Typechecker::VisitExpr_(const BinaryOpNode *op) {
throw TypecheckerError("BinaryOpNode not implmented");
}

Type Typechecker::VisitExpr_(const AssignmentNode *op) {
throw TypecheckerError("AssignmentNode not implemented");
}

TVM_REGISTER_API("nnvm.tyck.check")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntTypeNode::make(10);
});

} // namespace relay
} // namespace nnvm

0 comments on commit af85bb4

Please sign in to comment.