Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Refactor IR Printer: Type/Function/IRModule... to Doc #203

Merged
merged 1 commit into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/matxscript/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class PointerTypeNode : public TypeNode {
}

runtime::Unicode GetPythonTypeName() const override {
return U"pointer";
return U"matx.handle(" + element_type->GetPythonTypeName() + U")";
}

static constexpr const char* _type_key = "PointerType";
Expand Down Expand Up @@ -1005,7 +1005,7 @@ class IteratorTypeNode : public TypeNode {
}

runtime::Unicode GetPythonTypeName() const override {
return U"iterator";
return U"Iterable[" + container_type->GetPythonTypeName() + U"]";
}

static constexpr const char* _type_key = "IteratorType";
Expand Down
10 changes: 10 additions & 0 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,17 @@
#include <matxscript/ir/attrs.h>

#include <matxscript/ir/attr_functor.h>
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>
#include <matxscript/runtime/registry.h>

namespace matxscript {
namespace ir {

using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
Expand All @@ -57,6 +62,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << op->dict;
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<DictAttrs>("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc {
return d->AsDoc(attrs->dict, p->Attr("dict"));
});

MATXSCRIPT_REGISTER_NODE_TYPE(DictAttrsNode);

MATXSCRIPT_REGISTER_NODE_TYPE(AttrFieldInfoNode);
Expand Down
118 changes: 118 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@
*/
#include <matxscript/ir/function.h>

#include <matxscript/ir/_base/with.h>
#include <matxscript/ir/prim_ops.h>
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>
#include <matxscript/ir/type.h>
#include <matxscript/runtime/registry.h>

namespace matxscript {
namespace ir {

using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

bool BaseFuncNode::HasGlobalName() const {
auto global_symbol = GetAttr<StringRef>(attr::kGlobalSymbol);
Expand Down Expand Up @@ -171,6 +177,48 @@ MATXSCRIPT_REGISTER_GLOBAL("ir.PrimFunc")
std::move(attrs));
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::PrimFunc>("", [](ir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<IRFrame> f(d, func);
(*f)->AddDispatchToken(d, "ir");
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<ir::PrimVarNode>() || obj->IsInstance<ir::BufferNode>();
});
int n_args = func->params.size();
// Step 1. Handle `func->params`
int default_begin_pos = func->params.size() - func->default_params.size();
Array<AssignDoc> args;
args.reserve(n_args);
for (int i = 0; i < n_args; ++i) {
ir::PrimVar var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
Optional<ExprDoc> rhs = NullOpt;
if (i >= default_begin_pos) {
int def_pos = i - default_begin_pos;
rhs = d->AsDoc<ExprDoc>(func->default_params[def_pos],
p->Attr("default_params")->ArrayIndex(def_pos));
}
args.push_back(AssignDoc(DefineVar(var, *f, d), rhs, a));
}
// Step 2. Handle `func->attrs`
if (func->attrs.defined() && !func->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(Dialect(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
}
// Step 3. Handle `func->body`
AsDocBody(func->body, p->Attr("body"), f->get(), d);
Optional<ExprDoc> ret_type = NullOpt;
ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
return FunctionDoc(
/*name=*/IdDoc(func->GetGlobalName()),
/*args=*/args,
/*decorators=*/{Dialect(d, "kernel")},
/*return_type=*/ret_type,
/*body=*/(*f)->stmts);
});

/******************************************************************************
* Function
*****************************************************************************/
Expand Down Expand Up @@ -261,6 +309,49 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ", " << node->type_params << ", " << node->attrs << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::Function>("", [](ir::Function func, ObjectPath p, IRDocsifier d) -> Doc {
With<IRFrame> f(d, func);
(*f)->AddDispatchToken(d, "ir");
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<ir::PrimVarNode>() || obj->IsInstance<ir::HLOVarNode>() ||
obj->IsInstance<ir::BufferNode>();
});
int n_args = func->params.size();
// Step 1. Handle `func->params`
int default_begin_pos = func->params.size() - func->default_params.size();
Array<AssignDoc> args;
args.reserve(n_args);
for (int i = 0; i < n_args; ++i) {
ir::BaseExpr var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
ExprDoc a = d->AsDoc<ExprDoc>(var->checked_type_, var_p->Attr("checked_type_"));
Optional<ExprDoc> rhs = NullOpt;
if (i >= default_begin_pos) {
int def_pos = i - default_begin_pos;
rhs = d->AsDoc<ExprDoc>(func->default_params[def_pos],
p->Attr("default_params")->ArrayIndex(def_pos));
}
args.push_back(AssignDoc(DefineVar(var, *f, d), rhs, a));
}
// Step 2. Handle `func->attrs`
if (func->attrs.defined() && !func->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(Dialect(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
}
// Step 3. Handle `func->body`
AsDocBody(func->body, p->Attr("body"), f->get(), d);
Optional<ExprDoc> ret_type = NullOpt;
ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
return FunctionDoc(
/*name=*/IdDoc(func->GetGlobalName()),
/*args=*/args,
/*decorators=*/{Dialect(d, "script")},
/*return_type=*/ret_type,
/*body=*/(*f)->stmts);
});

/******************************************************************************
* LambdaFunction
*****************************************************************************/
Expand Down Expand Up @@ -346,6 +437,33 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->body << ", " << node->captures << ", " << node->attrs << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::LambdaFunction>(
"", [](ir::LambdaFunction func, ObjectPath p, IRDocsifier d) -> Doc {
With<IRFrame> f(d, func);
(*f)->AddDispatchToken(d, "ir");
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<ir::PrimVarNode>() || obj->IsInstance<ir::HLOVarNode>() ||
obj->IsInstance<ir::BufferNode>();
});
int n_args = func->params.size();
// Step 1. Handle `func->params`
Array<IdDoc> args;
args.reserve(n_args);
for (int i = 0; i < n_args; ++i) {
ir::BaseExpr var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
IdDoc a = d->AsDoc<IdDoc>(var, var_p);
args.push_back(a);
}
// TODO: fix lambda doc
// Step 2. Handle `func->body`
auto body = d->AsDoc<ExprDoc>(func->body, p->Attr("body"));
return LambdaDoc(
/*args=*/args,
/*body=*/body);
});

/******************************************************************************
* BaseFunc
*****************************************************************************/
Expand Down
65 changes: 64 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <unordered_set>

#include <matxscript/ir/_base/structural_equal.h>
#include <matxscript/ir/_base/with.h>
#include <matxscript/runtime/registry.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
Expand All @@ -43,11 +44,16 @@
#include <matxscript/ir/analysis.h>
#include <matxscript/ir/expr_functor.h>
// clang-format on
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>

namespace matxscript {
namespace ir {

using namespace runtime;
using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

IRModule::IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, ClassType> type_definitions,
Expand Down Expand Up @@ -410,5 +416,62 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "IRModule(" << node->functions << ")";
});

struct SortableFunction {
int priority;
GlobalVar gv;
BaseFunc func;

explicit SortableFunction(const std::pair<GlobalVar, BaseFunc>& obj)
: priority(0), gv(obj.first), func(obj.second) {
if (gv->name_hint == "main") {
priority = 1000;
} else if (obj.second->GetTypeKey() == "ir.PrimFunc") {
priority = 1;
} else if (obj.second->GetTypeKey() == "ir.LambdaFunction") {
priority = 2;
} else if (obj.second->GetTypeKey() == "ir.Function") {
priority = 3;
} else {
MXLOG(FATAL) << "TypeError: MATX cannot print functions of type: "
<< obj.second->GetTypeKey();
}
}

bool operator<(const SortableFunction& other) const {
if (this->priority != other.priority) {
return this->priority < other.priority;
}
return this->gv->name_hint < other.gv->name_hint;
}
};

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<IRModule>("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc {
std::vector<SortableFunction> functions;
for (const auto& kv : mod->functions) {
functions.push_back(SortableFunction(kv));
}
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d, ObjectRef{nullptr});
(*f)->AddDispatchToken(d, "ir");
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
d->cfg->binding_names.push_back(gv->name_hint);
Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
d->cfg->binding_names.pop_back();
if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
(*f)->stmts.push_back(stmt_block->stmts.back());
(*f)->stmts.back()->source_paths = std::move(doc->source_paths);
} else if (const auto* stmt = doc.as<StmtDocNode>()) {
(*f)->stmts.push_back(GetRef<StmtDoc>(stmt));
} else {
(*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
}
}
// TODO: use ModuleDoc instead
return ClassDoc(IdDoc("Module"), {Dialect(d, "ir_module")}, (*f)->stmts);
});

} // namespace ir
} // namespace matxscript
10 changes: 10 additions & 0 deletions src/ir/op_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@

#include <matxscript/ir/_base/attr_registry.h>
#include <matxscript/ir/op_expr.h>
#include <matxscript/ir/printer/doc.h>
#include <matxscript/ir/printer/ir_docsifier.h>
#include <matxscript/ir/printer/ir_frame.h>
#include <matxscript/ir/printer/utils.h>
#include <matxscript/ir/type.h>
#include <matxscript/runtime/container.h>
#include <matxscript/runtime/object_internal.h>
Expand All @@ -38,6 +42,7 @@ namespace matxscript {
namespace ir {

using namespace ::matxscript::runtime;
using namespace ::matxscript::ir::printer;

using OpRegistry = AttrRegistry<OpRegEntry, Op>;

Expand Down Expand Up @@ -125,6 +130,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Op(" << node->name << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
return Dialect(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
});

MATXSCRIPT_REGISTER_GLOBAL("ir._call_builtin_op").set_body([](PyArgs args) -> RTValue {
Type ret_type = args[0].As<Type>();
StringRef op_name = args[1].As<StringRef>();
Expand Down
6 changes: 3 additions & 3 deletions src/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<AllocaVarStmt>("", [](AllocaVarStmt s, ObjectPath p, IRDocsifier d) -> Doc {
ObjectPath var_attr = p->Attr("var");
auto lhs = d->AsDoc<ExprDoc>(s->var, var_attr);
ObjectPath var_p = p->Attr("var");
auto lhs = d->AsDoc<ExprDoc>(s->var, var_p);
Optional<ExprDoc> rhs;
Optional<ExprDoc> annotation;
if (s->init_value.defined()) {
rhs = d->AsDoc<ExprDoc>(s->init_value, p->Attr("init_value"));
}
annotation = LiteralDoc::HLOType(s->var->checked_type_, var_attr->Attr("checked_type_"));
annotation = d->AsDoc<ExprDoc>(s->var->checked_type_, var_p->Attr("checked_type_"));
return AssignDoc(lhs, rhs, annotation);
});

Expand Down
Loading