From 1c30176d670d98308d81e5cc13a8ef2a753b5b73 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 12 Feb 2017 02:47:14 +0000 Subject: [PATCH 1/2] [PASS] Change IRVisitor interfaces to function override --- include/tvm/ir_visitor.h | 31 ++++++ src/pass/ir_visitor.cc | 232 ++++++++++++++++++++++----------------- 2 files changed, 164 insertions(+), 99 deletions(-) diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index e5711f65ff86..512b6bc289aa 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -46,6 +46,37 @@ class IRVisitor { virtual void Visit_(const Let* op); virtual void Visit_(const Free* op); virtual void Visit_(const Call* op); + virtual void Visit_(const Add* op); + virtual void Visit_(const Sub* op); + virtual void Visit_(const Mul* op); + virtual void Visit_(const Div* op); + virtual void Visit_(const Mod* op); + virtual void Visit_(const Min* op); + virtual void Visit_(const Max* op); + virtual void Visit_(const EQ* op); + virtual void Visit_(const NE* op); + virtual void Visit_(const LT* op); + virtual void Visit_(const LE* op); + virtual void Visit_(const GT* op); + virtual void Visit_(const GE* op); + virtual void Visit_(const And* op); + virtual void Visit_(const Or* op); + virtual void Visit_(const Reduce* op); + virtual void Visit_(const Cast* op); + virtual void Visit_(const Not* op); + virtual void Visit_(const Select* op); + virtual void Visit_(const Ramp* op); + virtual void Visit_(const Broadcast* op); + virtual void Visit_(const AssertStmt* op); + virtual void Visit_(const ProducerConsumer* op); + virtual void Visit_(const Provide* op); + virtual void Visit_(const Realize* op); + virtual void Visit_(const Block* op); + virtual void Visit_(const Evaluate* op); + virtual void Visit_(const IntImm* op); + virtual void Visit_(const UIntImm* op); + virtual void Visit_(const FloatImm* op); + virtual void Visit_(const StringImm* op); }; /*! diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 5baaa851970e..f811028a31c8 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -34,9 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) static FVisit inst; return inst; } -void NoOp(const NodeRef& n, IRVisitor* v) { -} - inline void VisitArray(const Array& arr, IRVisitor* v) { for (size_t i = 0; i < arr.size(); i++) { v->Visit(arr[i]); @@ -51,24 +48,6 @@ inline void VisitRDom(const Array& rdom, IRVisitor* v) { } } -#define DISPATCH_TO_VISIT(OP) \ - set_dispatch([](const OP* op, IRVisitor* v) { \ - v->Visit_(op); \ - }) - -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.DISPATCH_TO_VISIT(Variable) -.DISPATCH_TO_VISIT(LetStmt) -.DISPATCH_TO_VISIT(AttrStmt) -.DISPATCH_TO_VISIT(IfThenElse) -.DISPATCH_TO_VISIT(For) -.DISPATCH_TO_VISIT(Allocate) -.DISPATCH_TO_VISIT(Load) -.DISPATCH_TO_VISIT(Store) -.DISPATCH_TO_VISIT(Let) -.DISPATCH_TO_VISIT(Call) -.DISPATCH_TO_VISIT(Free); - void IRVisitor::Visit_(const Variable* op) {} void IRVisitor::Visit_(const LetStmt *op) { @@ -128,91 +107,146 @@ void IRVisitor::Visit_(const Call *op) { VisitArray(op->args, this); } -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const Reduce* op, IRVisitor* v) { - VisitRDom(op->axis, v); - v->Visit(op->source); - }) -.set_dispatch(NoOp) -.set_dispatch(NoOp) -.set_dispatch(NoOp) -.set_dispatch(NoOp); +#define DEFINE_BINOP_VISIT_(OP) \ + void IRVisitor::Visit_(const OP* op) { \ + this->Visit(op->a); \ + this->Visit(op->b); \ + } -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const Cast* op, IRVisitor* v) { - v->Visit(op->value); - }); +DEFINE_BINOP_VISIT_(Add) +DEFINE_BINOP_VISIT_(Sub) +DEFINE_BINOP_VISIT_(Mul) +DEFINE_BINOP_VISIT_(Div) +DEFINE_BINOP_VISIT_(Mod) +DEFINE_BINOP_VISIT_(Min) +DEFINE_BINOP_VISIT_(Max) +DEFINE_BINOP_VISIT_(EQ) +DEFINE_BINOP_VISIT_(NE) +DEFINE_BINOP_VISIT_(LT) +DEFINE_BINOP_VISIT_(LE) +DEFINE_BINOP_VISIT_(GT) +DEFINE_BINOP_VISIT_(GE) +DEFINE_BINOP_VISIT_(And) +DEFINE_BINOP_VISIT_(Or) + +void IRVisitor::Visit_(const Reduce* op) { + VisitRDom(op->axis, this); + this->Visit(op->source); +} -// binary operator -template -inline void Binary(const T* op, IRVisitor* v) { - v->Visit(op->a); - v->Visit(op->b); +void IRVisitor::Visit_(const Cast* op) { + this->Visit(op->value); } -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch
(Binary
) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary); +void IRVisitor::Visit_(const Not* op) { + this->Visit(op->a); +} -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const Not* op, IRVisitor* v) { - v->Visit(op->a); - }) -.set_dispatch([](const Select *op, const Expr& e, IRMutator* m) { - Expr cond = m->Mutate(op->condition); - Expr t = m->Mutate(op->true_value); - Expr f = m->Mutate(op->false_value); - if (cond.same_as(op->condition) && - t.same_as(op->true_value) && - f.same_as(op->false_value)) { - return e; - } else { - return Select::make(cond, t, f); - } - }) -.set_dispatch([](const Ramp *op, const Expr& e, IRMutator* m) { - Expr base = m->Mutate(op->base); - Expr stride = m->Mutate(op->stride); - if (base.same_as(op->base) && - stride.same_as(op->stride)) { - return e; - } else { - return Ramp::make(base, stride, op->lanes); - } - }) -.set_dispatch([](const Broadcast *op, const Expr& e, IRMutator* m) { - Expr value = m->Mutate(op->value); - if (value.same_as(op->value)) { - return e; - } else { - return Broadcast::make(value, op->lanes); - } - }); +.DISPATCH_TO_MUTATE_EXPR(Variable) +.DISPATCH_TO_MUTATE_EXPR(LetStmt) +.DISPATCH_TO_MUTATE_EXPR(AttrStmt) +.DISPATCH_TO_MUTATE_EXPR(IfThenElse) +.DISPATCH_TO_MUTATE_EXPR(For) +.DISPATCH_TO_MUTATE_EXPR(Allocate) +.DISPATCH_TO_MUTATE_EXPR(Load) +.DISPATCH_TO_MUTATE_EXPR(Store) +.DISPATCH_TO_MUTATE_EXPR(Let) +.DISPATCH_TO_MUTATE_EXPR(Free) +.DISPATCH_TO_MUTATE_EXPR(Call) +.DISPATCH_TO_MUTATE_EXPR(Add) +.DISPATCH_TO_MUTATE_EXPR(Sub) +.DISPATCH_TO_MUTATE_EXPR(Mul) +.DISPATCH_TO_MUTATE_EXPR(Div) +.DISPATCH_TO_MUTATE_EXPR(Mod) +.DISPATCH_TO_MUTATE_EXPR(Min) +.DISPATCH_TO_MUTATE_EXPR(Max) +.DISPATCH_TO_MUTATE_EXPR(EQ) +.DISPATCH_TO_MUTATE_EXPR(NE) +.DISPATCH_TO_MUTATE_EXPR(LT) +.DISPATCH_TO_MUTATE_EXPR(LE) +.DISPATCH_TO_MUTATE_EXPR(GT) +.DISPATCH_TO_MUTATE_EXPR(GE) +.DISPATCH_TO_MUTATE_EXPR(And) +.DISPATCH_TO_MUTATE_EXPR(Or) +.DISPATCH_TO_MUTATE_EXPR(Reduce) +.DISPATCH_TO_MUTATE_EXPR(Cast) +.DISPATCH_TO_MUTATE_EXPR(Not) +.DISPATCH_TO_MUTATE_EXPR(Select) +.DISPATCH_TO_MUTATE_EXPR(Ramp) +.DISPATCH_TO_MUTATE_EXPR(Broadcast) +.DISPATCH_TO_MUTATE_EXPR(AssertStmt) +.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer) +.DISPATCH_TO_MUTATE_EXPR(Provide) +.DISPATCH_TO_MUTATE_EXPR(Realize) +.DISPATCH_TO_MUTATE_EXPR(Block) +.DISPATCH_TO_MUTATE_EXPR(Evaluate) +.DISPATCH_TO_MUTATE_EXPR(IntImm) +.DISPATCH_TO_MUTATE_EXPR(UIntImm) +.DISPATCH_TO_MUTATE_EXPR(FloatImm) +.DISPATCH_TO_MUTATE_EXPR(StringImm); -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) -.set_dispatch([](const AssertStmt *op, const Stmt& s, IRMutator* m) { - Expr condition = m->Mutate(op->condition); - Expr message = m->Mutate(op->message); - - if (condition.same_as(op->condition) && message.same_as(op->message)) { - return s; - } else { - return AssertStmt::make(condition, message); - } - }) -.set_dispatch([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) { - Stmt body = m->Mutate(op->body); - if (body.same_as(op->body)) { - return s; - } else { - return ProducerConsumer::make(op->func, op->is_producer, body); - } - }) -.set_dispatch([](const Evaluate *op, const Stmt& s, IRMutator* m) { - Expr v = m->Mutate(op->value); - if (v.same_as(op->value)) { - return s; - } else { - return Evaluate::make(v); - } - }); } // namespace ir } // namespace tvm diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index f811028a31c8..f82f9130fca9 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -214,8 +214,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Load) .DISPATCH_TO_VISIT(Store) .DISPATCH_TO_VISIT(Let) -.DISPATCH_TO_VISIT(Call) .DISPATCH_TO_VISIT(Free) +.DISPATCH_TO_VISIT(Call) .DISPATCH_TO_VISIT(Add) .DISPATCH_TO_VISIT(Sub) .DISPATCH_TO_VISIT(Mul)