Skip to content

Commit

Permalink
[IR] Refactor IR Printer: Type/Function/IRModule... to Doc (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-mxd authored and jc-bytedance committed Mar 28, 2023
1 parent 753226c commit 615544a
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 6 deletions.
4 changes: 2 additions & 2 deletions include/matxscript/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class PointerTypeNode : public TypeNode {
}

runtime::Unicode GetPythonTypeName() const override {
return U"pointer";
return U"matx.handle(" + element_type->GetPythonTypeName() + U")";
}

static constexpr const char* _type_key = "PointerType";
Expand Down Expand Up @@ -1005,7 +1005,7 @@ class IteratorTypeNode : public TypeNode {
}

runtime::Unicode GetPythonTypeName() const override {
return U"iterator";
return U"Iterable[" + container_type->GetPythonTypeName() + U"]";
}

static constexpr const char* _type_key = "IteratorType";
Expand Down
10 changes: 10 additions & 0 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,17 @@
#include <matxscript/ir/attrs.h>

#include <matxscript/ir/attr_functor.h>
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>
#include <matxscript/runtime/registry.h>

namespace matxscript {
namespace ir {

using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
Expand All @@ -57,6 +62,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << op->dict;
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<DictAttrs>("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc {
return d->AsDoc(attrs->dict, p->Attr("dict"));
});

MATXSCRIPT_REGISTER_NODE_TYPE(DictAttrsNode);

MATXSCRIPT_REGISTER_NODE_TYPE(AttrFieldInfoNode);
Expand Down
118 changes: 118 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@
*/
#include <matxscript/ir/function.h>

#include <matxscript/ir/_base/with.h>
#include <matxscript/ir/prim_ops.h>
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>
#include <matxscript/ir/type.h>
#include <matxscript/runtime/registry.h>

namespace matxscript {
namespace ir {

using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

bool BaseFuncNode::HasGlobalName() const {
auto global_symbol = GetAttr<StringRef>(attr::kGlobalSymbol);
Expand Down Expand Up @@ -171,6 +177,48 @@ MATXSCRIPT_REGISTER_GLOBAL("ir.PrimFunc")
std::move(attrs));
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::PrimFunc>("", [](ir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<IRFrame> f(d, func);
(*f)->AddDispatchToken(d, "ir");
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<ir::PrimVarNode>() || obj->IsInstance<ir::BufferNode>();
});
int n_args = func->params.size();
// Step 1. Handle `func->params`
int default_begin_pos = func->params.size() - func->default_params.size();
Array<AssignDoc> args;
args.reserve(n_args);
for (int i = 0; i < n_args; ++i) {
ir::PrimVar var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
Optional<ExprDoc> rhs = NullOpt;
if (i >= default_begin_pos) {
int def_pos = i - default_begin_pos;
rhs = d->AsDoc<ExprDoc>(func->default_params[def_pos],
p->Attr("default_params")->ArrayIndex(def_pos));
}
args.push_back(AssignDoc(DefineVar(var, *f, d), rhs, a));
}
// Step 2. Handle `func->attrs`
if (func->attrs.defined() && !func->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(Dialect(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
}
// Step 3. Handle `func->body`
AsDocBody(func->body, p->Attr("body"), f->get(), d);
Optional<ExprDoc> ret_type = NullOpt;
ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
return FunctionDoc(
/*name=*/IdDoc(func->GetGlobalName()),
/*args=*/args,
/*decorators=*/{Dialect(d, "kernel")},
/*return_type=*/ret_type,
/*body=*/(*f)->stmts);
});

/******************************************************************************
* Function
*****************************************************************************/
Expand Down Expand Up @@ -261,6 +309,49 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ", " << node->type_params << ", " << node->attrs << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::Function>("", [](ir::Function func, ObjectPath p, IRDocsifier d) -> Doc {
With<IRFrame> f(d, func);
(*f)->AddDispatchToken(d, "ir");
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<ir::PrimVarNode>() || obj->IsInstance<ir::HLOVarNode>() ||
obj->IsInstance<ir::BufferNode>();
});
int n_args = func->params.size();
// Step 1. Handle `func->params`
int default_begin_pos = func->params.size() - func->default_params.size();
Array<AssignDoc> args;
args.reserve(n_args);
for (int i = 0; i < n_args; ++i) {
ir::BaseExpr var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
ExprDoc a = d->AsDoc<ExprDoc>(var->checked_type_, var_p->Attr("checked_type_"));
Optional<ExprDoc> rhs = NullOpt;
if (i >= default_begin_pos) {
int def_pos = i - default_begin_pos;
rhs = d->AsDoc<ExprDoc>(func->default_params[def_pos],
p->Attr("default_params")->ArrayIndex(def_pos));
}
args.push_back(AssignDoc(DefineVar(var, *f, d), rhs, a));
}
// Step 2. Handle `func->attrs`
if (func->attrs.defined() && !func->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(Dialect(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
}
// Step 3. Handle `func->body`
AsDocBody(func->body, p->Attr("body"), f->get(), d);
Optional<ExprDoc> ret_type = NullOpt;
ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
return FunctionDoc(
/*name=*/IdDoc(func->GetGlobalName()),
/*args=*/args,
/*decorators=*/{Dialect(d, "script")},
/*return_type=*/ret_type,
/*body=*/(*f)->stmts);
});

/******************************************************************************
* LambdaFunction
*****************************************************************************/
Expand Down Expand Up @@ -346,6 +437,33 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->body << ", " << node->captures << ", " << node->attrs << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::LambdaFunction>(
"", [](ir::LambdaFunction func, ObjectPath p, IRDocsifier d) -> Doc {
With<IRFrame> f(d, func);
(*f)->AddDispatchToken(d, "ir");
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<ir::PrimVarNode>() || obj->IsInstance<ir::HLOVarNode>() ||
obj->IsInstance<ir::BufferNode>();
});
int n_args = func->params.size();
// Step 1. Handle `func->params`
Array<IdDoc> args;
args.reserve(n_args);
for (int i = 0; i < n_args; ++i) {
ir::BaseExpr var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
IdDoc a = d->AsDoc<IdDoc>(var, var_p);
args.push_back(a);
}
// TODO: fix lambda doc
// Step 2. Handle `func->body`
auto body = d->AsDoc<ExprDoc>(func->body, p->Attr("body"));
return LambdaDoc(
/*args=*/args,
/*body=*/body);
});

/******************************************************************************
* BaseFunc
*****************************************************************************/
Expand Down
65 changes: 64 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <unordered_set>

#include <matxscript/ir/_base/structural_equal.h>
#include <matxscript/ir/_base/with.h>
#include <matxscript/runtime/registry.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
Expand All @@ -43,11 +44,16 @@
#include <matxscript/ir/analysis.h>
#include <matxscript/ir/expr_functor.h>
// clang-format on
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>

namespace matxscript {
namespace ir {

using namespace runtime;
using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

IRModule::IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, ClassType> type_definitions,
Expand Down Expand Up @@ -410,5 +416,62 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "IRModule(" << node->functions << ")";
});

struct SortableFunction {
int priority;
GlobalVar gv;
BaseFunc func;

explicit SortableFunction(const std::pair<GlobalVar, BaseFunc>& obj)
: priority(0), gv(obj.first), func(obj.second) {
if (gv->name_hint == "main") {
priority = 1000;
} else if (obj.second->GetTypeKey() == "ir.PrimFunc") {
priority = 1;
} else if (obj.second->GetTypeKey() == "ir.LambdaFunction") {
priority = 2;
} else if (obj.second->GetTypeKey() == "ir.Function") {
priority = 3;
} else {
MXLOG(FATAL) << "TypeError: MATX cannot print functions of type: "
<< obj.second->GetTypeKey();
}
}

bool operator<(const SortableFunction& other) const {
if (this->priority != other.priority) {
return this->priority < other.priority;
}
return this->gv->name_hint < other.gv->name_hint;
}
};

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<IRModule>("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc {
std::vector<SortableFunction> functions;
for (const auto& kv : mod->functions) {
functions.push_back(SortableFunction(kv));
}
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d, ObjectRef{nullptr});
(*f)->AddDispatchToken(d, "ir");
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
d->cfg->binding_names.push_back(gv->name_hint);
Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
d->cfg->binding_names.pop_back();
if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
(*f)->stmts.push_back(stmt_block->stmts.back());
(*f)->stmts.back()->source_paths = std::move(doc->source_paths);
} else if (const auto* stmt = doc.as<StmtDocNode>()) {
(*f)->stmts.push_back(GetRef<StmtDoc>(stmt));
} else {
(*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
}
}
// TODO: use ModuleDoc instead
return ClassDoc(IdDoc("Module"), {Dialect(d, "ir_module")}, (*f)->stmts);
});

} // namespace ir
} // namespace matxscript
10 changes: 10 additions & 0 deletions src/ir/op_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@

#include <matxscript/ir/_base/attr_registry.h>
#include <matxscript/ir/op_expr.h>
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>
#include <matxscript/ir/type.h>
#include <matxscript/runtime/container.h>
#include <matxscript/runtime/object_internal.h>
Expand All @@ -38,6 +42,7 @@ namespace matxscript {
namespace ir {

using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

using OpRegistry = AttrRegistry<OpRegEntry, Op>;

Expand Down Expand Up @@ -125,6 +130,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Op(" << node->name << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
return Dialect(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
});

MATXSCRIPT_REGISTER_GLOBAL("ir._call_builtin_op").set_body([](PyArgs args) -> RTValue {
Type ret_type = args[0].As<Type>();
StringRef op_name = args[1].As<StringRef>();
Expand Down
6 changes: 3 additions & 3 deletions src/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<AllocaVarStmt>("", [](AllocaVarStmt s, ObjectPath p, IRDocsifier d) -> Doc {
ObjectPath var_attr = p->Attr("var");
auto lhs = d->AsDoc<ExprDoc>(s->var, var_attr);
ObjectPath var_p = p->Attr("var");
auto lhs = d->AsDoc<ExprDoc>(s->var, var_p);
Optional<ExprDoc> rhs;
Optional<ExprDoc> annotation;
if (s->init_value.defined()) {
rhs = d->AsDoc<ExprDoc>(s->init_value, p->Attr("init_value"));
}
annotation = LiteralDoc::HLOType(s->var->checked_type_, var_attr->Attr("checked_type_"));
annotation = d->AsDoc<ExprDoc>(s->var->checked_type_, var_p->Attr("checked_type_"));
return AssignDoc(lhs, rhs, annotation);
});

Expand Down
Loading

0 comments on commit 615544a

Please sign in to comment.