From 99d56571ad20b9d8141e192d13654bea6b8b0f6d Mon Sep 17 00:00:00 2001 From: maxiandi Date: Tue, 21 Mar 2023 16:25:41 +0800 Subject: [PATCH] [IR] Refactor IR Printer: Type/Function/IRModule... to Doc --- include/matxscript/ir/type.h | 4 +- src/ir/attrs.cc | 10 +++ src/ir/function.cc | 118 ++++++++++++++++++++++++++ src/ir/module.cc | 65 ++++++++++++++- src/ir/op_expr.cc | 10 +++ src/ir/stmt.cc | 6 +- src/ir/type.cc | 155 +++++++++++++++++++++++++++++++++++ 7 files changed, 362 insertions(+), 6 deletions(-) diff --git a/include/matxscript/ir/type.h b/include/matxscript/ir/type.h index 5094a60e..15a8ff47 100644 --- a/include/matxscript/ir/type.h +++ b/include/matxscript/ir/type.h @@ -238,7 +238,7 @@ class PointerTypeNode : public TypeNode { } runtime::Unicode GetPythonTypeName() const override { - return U"pointer"; + return U"matx.handle(" + element_type->GetPythonTypeName() + U")"; } static constexpr const char* _type_key = "PointerType"; @@ -1005,7 +1005,7 @@ class IteratorTypeNode : public TypeNode { } runtime::Unicode GetPythonTypeName() const override { - return U"iterator"; + return U"Iterable[" + container_type->GetPythonTypeName() + U"]"; } static constexpr const char* _type_key = "IteratorType"; diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 2ef07db3..5cbdd333 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -26,12 +26,17 @@ #include #include +#include +#include +#include +#include #include namespace matxscript { namespace ir { using namespace ::matxscript::runtime; +using namespace ::matxscript::ir::printer; void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); @@ -57,6 +62,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << op->dict; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc { + return d->AsDoc(attrs->dict, p->Attr("dict")); + }); + MATXSCRIPT_REGISTER_NODE_TYPE(DictAttrsNode); MATXSCRIPT_REGISTER_NODE_TYPE(AttrFieldInfoNode); diff --git a/src/ir/function.cc b/src/ir/function.cc index b7fe90a3..dcfb5b17 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -26,7 +26,12 @@ */ #include +#include #include +#include +#include +#include +#include #include #include @@ -34,6 +39,7 @@ namespace matxscript { namespace ir { using namespace ::matxscript::runtime; +using namespace ::matxscript::ir::printer; bool BaseFuncNode::HasGlobalName() const { auto global_symbol = GetAttr(attr::kGlobalSymbol); @@ -171,6 +177,48 @@ MATXSCRIPT_REGISTER_GLOBAL("ir.PrimFunc") std::move(attrs)); }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { + With f(d, func); + (*f)->AddDispatchToken(d, "ir"); + d->SetCommonPrefix(func, [](const ObjectRef& obj) { + return obj->IsInstance() || obj->IsInstance(); + }); + int n_args = func->params.size(); + // Step 1. Handle `func->params` + int default_begin_pos = func->params.size() - func->default_params.size(); + Array args; + args.reserve(n_args); + for (int i = 0; i < n_args; ++i) { + ir::PrimVar var = func->params[i]; + ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); + Optional rhs = NullOpt; + if (i >= default_begin_pos) { + int def_pos = i - default_begin_pos; + rhs = d->AsDoc(func->default_params[def_pos], + p->Attr("default_params")->ArrayIndex(def_pos)); + } + args.push_back(AssignDoc(DefineVar(var, *f, d), rhs, a)); + } + // Step 2. Handle `func->attrs` + if (func->attrs.defined() && !func->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(Dialect(d, "func_attr") // + ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); + } + // Step 3. Handle `func->body` + AsDocBody(func->body, p->Attr("body"), f->get(), d); + Optional ret_type = NullOpt; + ret_type = d->AsDoc(func->ret_type, p->Attr("ret_type")); + return FunctionDoc( + /*name=*/IdDoc(func->GetGlobalName()), + /*args=*/args, + /*decorators=*/{Dialect(d, "kernel")}, + /*return_type=*/ret_type, + /*body=*/(*f)->stmts); + }); + /****************************************************************************** * Function *****************************************************************************/ @@ -261,6 +309,49 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->type_params << ", " << node->attrs << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::Function func, ObjectPath p, IRDocsifier d) -> Doc { + With f(d, func); + (*f)->AddDispatchToken(d, "ir"); + d->SetCommonPrefix(func, [](const ObjectRef& obj) { + return obj->IsInstance() || obj->IsInstance() || + obj->IsInstance(); + }); + int n_args = func->params.size(); + // Step 1. Handle `func->params` + int default_begin_pos = func->params.size() - func->default_params.size(); + Array args; + args.reserve(n_args); + for (int i = 0; i < n_args; ++i) { + ir::BaseExpr var = func->params[i]; + ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + ExprDoc a = d->AsDoc(var->checked_type_, var_p->Attr("checked_type_")); + Optional rhs = NullOpt; + if (i >= default_begin_pos) { + int def_pos = i - default_begin_pos; + rhs = d->AsDoc(func->default_params[def_pos], + p->Attr("default_params")->ArrayIndex(def_pos)); + } + args.push_back(AssignDoc(DefineVar(var, *f, d), rhs, a)); + } + // Step 2. Handle `func->attrs` + if (func->attrs.defined() && !func->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(Dialect(d, "func_attr") // + ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); + } + // Step 3. Handle `func->body` + AsDocBody(func->body, p->Attr("body"), f->get(), d); + Optional ret_type = NullOpt; + ret_type = d->AsDoc(func->ret_type, p->Attr("ret_type")); + return FunctionDoc( + /*name=*/IdDoc(func->GetGlobalName()), + /*args=*/args, + /*decorators=*/{Dialect(d, "script")}, + /*return_type=*/ret_type, + /*body=*/(*f)->stmts); + }); + /****************************************************************************** * LambdaFunction *****************************************************************************/ @@ -346,6 +437,33 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->body << ", " << node->captures << ", " << node->attrs << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](ir::LambdaFunction func, ObjectPath p, IRDocsifier d) -> Doc { + With f(d, func); + (*f)->AddDispatchToken(d, "ir"); + d->SetCommonPrefix(func, [](const ObjectRef& obj) { + return obj->IsInstance() || obj->IsInstance() || + obj->IsInstance(); + }); + int n_args = func->params.size(); + // Step 1. Handle `func->params` + Array args; + args.reserve(n_args); + for (int i = 0; i < n_args; ++i) { + ir::BaseExpr var = func->params[i]; + ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + IdDoc a = d->AsDoc(var, var_p); + args.push_back(a); + } + // TODO: fix lambda doc + // Step 2. Handle `func->body` + auto body = d->AsDoc(func->body, p->Attr("body")); + return LambdaDoc( + /*args=*/args, + /*body=*/body); + }); + /****************************************************************************** * BaseFunc *****************************************************************************/ diff --git a/src/ir/module.cc b/src/ir/module.cc index b6b71376..c7767323 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -33,6 +33,7 @@ #include #include +#include #include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, @@ -43,11 +44,16 @@ #include #include // clang-format on +#include +#include +#include +#include namespace matxscript { namespace ir { -using namespace runtime; +using namespace ::matxscript::runtime; +using namespace ::matxscript::ir::printer; IRModule::IRModule(Map functions, Map type_definitions, @@ -410,5 +416,62 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "IRModule(" << node->functions << ")"; }); +struct SortableFunction { + int priority; + GlobalVar gv; + BaseFunc func; + + explicit SortableFunction(const std::pair& obj) + : priority(0), gv(obj.first), func(obj.second) { + if (gv->name_hint == "main") { + priority = 1000; + } else if (obj.second->GetTypeKey() == "ir.PrimFunc") { + priority = 1; + } else if (obj.second->GetTypeKey() == "ir.LambdaFunction") { + priority = 2; + } else if (obj.second->GetTypeKey() == "ir.Function") { + priority = 3; + } else { + MXLOG(FATAL) << "TypeError: MATX cannot print functions of type: " + << obj.second->GetTypeKey(); + } + } + + bool operator<(const SortableFunction& other) const { + if (this->priority != other.priority) { + return this->priority < other.priority; + } + return this->gv->name_hint < other.gv->name_hint; + } +}; + +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc { + std::vector functions; + for (const auto& kv : mod->functions) { + functions.push_back(SortableFunction(kv)); + } + std::sort(functions.begin(), functions.end()); + With f(d, ObjectRef{nullptr}); + (*f)->AddDispatchToken(d, "ir"); + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + const BaseFunc& func = entry.func; + d->cfg->binding_names.push_back(gv->name_hint); + Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv)); + d->cfg->binding_names.pop_back(); + if (const auto* stmt_block = doc.as()) { + (*f)->stmts.push_back(stmt_block->stmts.back()); + (*f)->stmts.back()->source_paths = std::move(doc->source_paths); + } else if (const auto* stmt = doc.as()) { + (*f)->stmts.push_back(GetRef(stmt)); + } else { + (*f)->stmts.push_back(Downcast(doc)); + } + } + // TODO: use ModuleDoc instead + return ClassDoc(IdDoc("Module"), {Dialect(d, "ir_module")}, (*f)->stmts); + }); + } // namespace ir } // namespace matxscript diff --git a/src/ir/op_expr.cc b/src/ir/op_expr.cc index b40971ff..4a12ce52 100644 --- a/src/ir/op_expr.cc +++ b/src/ir/op_expr.cc @@ -30,6 +30,10 @@ #include #include +#include +#include +#include +#include #include #include #include @@ -38,6 +42,7 @@ namespace matxscript { namespace ir { using namespace ::matxscript::runtime; +using namespace ::matxscript::ir::printer; using OpRegistry = AttrRegistry; @@ -125,6 +130,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Op(" << node->name << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { + return Dialect(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); + }); + MATXSCRIPT_REGISTER_GLOBAL("ir._call_builtin_op").set_body([](PyArgs args) -> RTValue { Type ret_type = args[0].As(); StringRef op_name = args[1].As(); diff --git a/src/ir/stmt.cc b/src/ir/stmt.cc index a758000e..ef6fced8 100644 --- a/src/ir/stmt.cc +++ b/src/ir/stmt.cc @@ -125,14 +125,14 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](AllocaVarStmt s, ObjectPath p, IRDocsifier d) -> Doc { - ObjectPath var_attr = p->Attr("var"); - auto lhs = d->AsDoc(s->var, var_attr); + ObjectPath var_p = p->Attr("var"); + auto lhs = d->AsDoc(s->var, var_p); Optional rhs; Optional annotation; if (s->init_value.defined()) { rhs = d->AsDoc(s->init_value, p->Attr("init_value")); } - annotation = LiteralDoc::HLOType(s->var->checked_type_, var_attr->Attr("checked_type_")); + annotation = d->AsDoc(s->var->checked_type_, var_p->Attr("checked_type_")); return AssignDoc(lhs, rhs, annotation); }); diff --git a/src/ir/type.cc b/src/ir/type.cc index 76a7bb50..d0c526c8 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -28,13 +28,26 @@ #include #include +#include #include +#include +#include +#include +#include #include namespace matxscript { namespace ir { using namespace ::matxscript::runtime; +using namespace ::matxscript::ir::printer; + +static StringRef GetLiteralRepr(const Type& ty) { + if (auto const* pt = ty.as()) { + return pt->dtype.is_void() ? "void" : runtime::DLDataType2String(pt->dtype); + } + return ty->GetPythonTypeName().encode(); +} bool IsRuntimeDataType(const Type& type) { if (auto* n = type.as()) { @@ -83,6 +96,14 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << node->dtype; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc { + if (ty->dtype == DataType::Int(64) || ty->dtype == DataType::Float(64)) { + return IdDoc(GetLiteralRepr(ty)); + } + return Dialect(d, GetLiteralRepr(ty)); + }); + MATXSCRIPT_REGISTER_GLOBAL("ir.VoidType").set_body_typed([]() { return VoidType(); }); MATXSCRIPT_REGISTER_GLOBAL("ir.IsVoidType").set_body_typed([](const Type& type) { return IsVoidType(type); @@ -107,6 +128,18 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '*'; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](PointerType ty, ObjectPath ty_p, IRDocsifier d) -> Doc { + ExprDoc element_type{nullptr}; + if (const auto* prim_type = ty->element_type.as()) { + element_type = LiteralDoc::DataType(prim_type->dtype, // + ty_p->Attr("element_type")->Attr("dtype")); + } else { + element_type = d->AsDoc(ty->element_type, ty_p->Attr("element_type")); + } + return Dialect(d, "handle")->Call({element_type}); + }); + TypeVar::TypeVar(StringRef name, TypeKind kind, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name); @@ -127,6 +160,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc { + return Dialect(d, "TypeVar") + ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), // + LiteralDoc::Int(var->kind, p->Attr("kind"))}); + }); + GlobalTypeVar::GlobalTypeVar(StringRef name, TypeKind kind, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name); @@ -147,6 +187,15 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", + [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc { + return Dialect(d, "GlobalTypeVar") + ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), + LiteralDoc::Int(var->kind, p->Attr("kind"))}); + }); + FuncType::FuncType(Array arg_types, Type ret_type, Array type_params, @@ -181,6 +230,16 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->ret_type << ", " << node->type_constraints << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { + return Dialect(d, "FuncType") + ->Call({ + d->AsDoc(func_type->type_params, p->Attr("type_params")), + d->AsDoc(func_type->arg_types, p->Attr("arg_types")), + d->AsDoc(func_type->ret_type, p->Attr("ret_type")), + }); + }); + TupleType::TupleType(Array fields, bool is_std_tuple, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -213,6 +272,21 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TupleTypeNode(" << node->fields << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TupleType ty, ObjectPath p, IRDocsifier d) -> Doc { + if (ty->fields.empty()) { + return LiteralDoc::None(p); + } + p = p->Attr("fields"); + int n = ty->fields.size(); + Array elements; + elements.reserve(n); + for (int i = 0; i < n; ++i) { + elements.push_back(d->AsDoc(ty->fields[i], p->ArrayIndex(i))); + } + return TupleDoc(std::move(elements)); + }); + // Range Type RangeType::RangeType(Span span) { ObjectPtr n = make_object(); @@ -233,6 +307,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RangeTypeNode"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](RangeType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc("range"); + }); + // Object Type ObjectType::ObjectType(bool is_view, Span span) { ObjectPtr n = make_object(); @@ -253,6 +332,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "ObjectTypeNode"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ObjectType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc("Any"); + }); + // String Type StringType::StringType(bool is_view, Span span) { ObjectPtr n = make_object(); @@ -273,6 +357,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "StringTypeNode"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](StringType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc("bytes"); + }); + // Unicode Type UnicodeType::UnicodeType(bool is_view, Span span) { ObjectPtr n = make_object(); @@ -293,6 +382,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "UnicodeTypeNode"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](UnicodeType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc("str"); + }); + // List Type ListType::ListType(bool is_full_typed, Type item_type, Span span) { ObjectPtr n = make_object(); @@ -317,6 +411,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "ListTypeNode(" << node->item_type << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ListType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // Dict Type DictType::DictType(bool is_full_typed, Type key_type, Type value_type, Span span) { ObjectPtr n = make_object(); @@ -341,6 +440,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", value_type: " << node->value_type << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](DictType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // Set Type SetType::SetType(bool is_full_typed, Type item_type, Span span) { ObjectPtr n = make_object(); @@ -362,6 +466,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "SetTypeNode(" << node->item_type << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](SetType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // IteratorType IteratorType::IteratorType(Type container_type, Span span) { ObjectPtr n = make_object(); @@ -394,6 +503,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "_Iterator"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](IteratorType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // ExceptionType ExceptionType::ExceptionType(StringRef name, Span span) { ObjectPtr n = make_object(); @@ -413,6 +527,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(node->name); }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ExceptionType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // FileType FileType::FileType(bool binary_mode, Span span) { ObjectPtr n = make_object(); @@ -433,6 +552,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "FileTypeNode(binary_mode=" << node->binary_mode << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](FileType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // TrieType TrieType::TrieType(Span span) { ObjectPtr n = make_object(); @@ -449,6 +573,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TrieType"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TrieType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // UserDataType UserDataType::UserDataType(Span span) { ObjectPtr n = make_object(); @@ -465,6 +594,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "UserDataType"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](UserDataType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // NDArrayType runtime::Unicode NDArrayTypeNode::GetPythonTypeName() const { std::stringstream os; @@ -514,11 +648,21 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RegexType"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](RegexType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { p->stream << "NDArrayType"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](NDArrayType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + MATXSCRIPT_REGISTER_GLOBAL("ir.Type_GetPythonTypeName").set_body_typed([](Type ty) { return ty->GetPythonTypeName(); }); @@ -549,6 +693,12 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "OpaqueObjectType"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", + [](OpaqueObjectType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + // Ref Type RefType::RefType(Type value, Span span) { ObjectPtr n = make_object(); @@ -569,6 +719,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefTypeNode(" << node->value << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](RefType ty, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc(GetLiteralRepr(ty)); + }); + Type InferIteratorValueType(const Type& cons_ty) { if (auto* ptr = cons_ty.as()) { return ptr->item_type;