diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 094d3fdf51df5..813c217e2b424 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -243,40 +243,48 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ObjectRef value); - LiteralDoc(ObjectRef value, ObjectPath object_path); + explicit LiteralDoc(ObjectRef value, const Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. */ - static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + static LiteralDoc None(const Optional& p) { + return LiteralDoc(ObjectRef(nullptr), p); + } /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ - static LiteralDoc Int(int64_t v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + static LiteralDoc Int(int64_t v, const Optional& p) { + return LiteralDoc(IntImm(DataType::Int(64), v), p); + } /*! * \brief Create a LiteralDoc to represent boolean. * \param v The boolean value. */ - static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + static LiteralDoc Boolean(bool v, const Optional& p) { + return LiteralDoc(IntImm(DataType::Bool(), v), p); + } /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ - static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } + static LiteralDoc Float(double v, const Optional& p) { + return LiteralDoc(FloatImm(DataType::Float(64), v), p); + } /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ - static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + static LiteralDoc Str(const String& v, const Optional& p) { return LiteralDoc(v, p); } /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ - static LiteralDoc DataType(const DLDataType& v) { - return LiteralDoc::Str(runtime::DLDataType2String(v)); + static LiteralDoc DataType(const runtime::DataType& v, const Optional& p) { + std::string dtype = v.is_void() ? "void" : runtime::DLDataType2String(v); + return LiteralDoc::Str(dtype, p); } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index e426946b56fef..e0419b469505e 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -259,11 +259,12 @@ inline void FrameNode::ExitWithScope() { template inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const { - if (!obj.defined()) { - return Downcast(LiteralDoc::None()); + if (obj.defined()) { + Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef(this)); + d->source_paths.push_back(path); + return Downcast(d); } - return Downcast( - IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef(this))); + return Downcast(LiteralDoc::None(path)); } inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) { diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index a93957d3e18f0..5a4a4cd67a72c 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -155,17 +155,37 @@ class LiteralDoc(ExprDoc): value: Union[str, IntImm, FloatImm, None] - def __init__(self, value: Union[str, float, bool, int, None]): + def __init__( + self, + value: Union[str, float, bool, int, None], + path: Optional[ObjectPath] = None, + ): if value is None: - self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore # pylint: disable=no-member + self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone, path) # type: ignore # pylint: disable=no-member elif isinstance(value, str): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore # pylint: disable=no-member + self.__init_handle_by_constructor__( + _ffi_api.LiteralDocStr, # type: ignore # pylint: disable=no-member + value, + path, + ) elif isinstance(value, float): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore # pylint: disable=no-member + self.__init_handle_by_constructor__( + _ffi_api.LiteralDocFloat, # type: ignore # pylint: disable=no-member + value, + path, + ) elif isinstance(value, bool): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore # pylint: disable=no-member + self.__init_handle_by_constructor__( + _ffi_api.LiteralDocBoolean, # type: ignore # pylint: disable=no-member + value, + path, + ) elif isinstance(value, int): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore # pylint: disable=no-member + self.__init_handle_by_constructor__( + _ffi_api.LiteralDocInt, # type: ignore # pylint: disable=no-member + value, + path, + ) else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index f41b40c92cc9c..89f6b7c8b1cfc 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -48,16 +48,12 @@ StmtBlockDoc::StmtBlockDoc(Array stmts) { this->data_ = std::move(n); } -LiteralDoc::LiteralDoc(ObjectRef value) { +LiteralDoc::LiteralDoc(ObjectRef value, const Optional& object_path) { ObjectPtr n = make_object(); n->value = value; - this->data_ = std::move(n); -} - -LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) { - ObjectPtr n = make_object(); - n->value = value; - n->source_paths.push_back(object_path); + if (object_path.defined()) { + n->source_paths.push_back(object_path.value()); + } this->data_ = std::move(n); } @@ -250,15 +246,11 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array(LiteralDoc::None); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") - .set_body_typed(LiteralDoc::Int); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean") - .set_body_typed(LiteralDoc::Boolean); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") - .set_body_typed(LiteralDoc::Float); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr") - .set_body_typed(LiteralDoc::Str); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 5cd459be66964..e438919f4b1b2 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -63,26 +63,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { - return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)}); + return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { - return IR("Op")->Call({LiteralDoc::Str(op->name)}); + return IR("Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](TypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc { - return IR("TypeVar")->Call({LiteralDoc::Str(type_var->name_hint), // - LiteralDoc::Str(TypeKind2String(type_var->kind))}); + .set_dispatch("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc { + return IR("TypeVar")->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), // + LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](GlobalTypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc { + "", [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc { return IR("GlobalTypeVar") - ->Call({LiteralDoc::Str(type_var->name_hint), // - LiteralDoc::Str(TypeKind2String(type_var->kind))}); + ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), + LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -94,7 +94,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc { return IR("TensorType") ->Call({d->AsDoc(type->shape, p->Attr("shape")), - LiteralDoc::DataType(type->dtype)}); + LiteralDoc::DataType(type->dtype, p->Attr("dtype"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index bd27921671947..cb78dc3ff5c33 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -24,7 +24,7 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](String s, ObjectPath p, IRDocsifier d) -> Doc { - return LiteralDoc::Str(s); + return LiteralDoc::Str(s, p); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index f264dfee8d504..2909e059f3e3d 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -588,7 +588,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - // TODO(tvm-team) redirect to Text printer once we have a good text format. auto* node = static_cast(ref.get()); (*p) << "PrimFunc(" << node->params << ") "; if (node->attrs.defined()) { diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 069ec7f3ea415..f78e7037c3e06 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -118,16 +118,20 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // lhs.reserve(m); loop_var_doc.reserve(m); std::string binding_type = ""; + Array binding_paths; for (int i : remap_vars_indices) { tir::IterVar iter_var = block->iter_vars[i]; - ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); + ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i); lhs.push_back(DefineVar(iter_var->var, *frame, d)); loop_var_doc.push_back(d->AsDoc(realize->iter_values[i], realize_p->Attr("iter_values")->ArrayIndex(i))); + binding_paths.push_back(iter_var_p->Attr("iter_type")); binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; } ExprDoc rhs = TIR("axis")->Attr("remap"); - rhs = rhs->Call({LiteralDoc::Str(binding_type), ListDoc(loop_var_doc)}); + ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt); + binding_str->source_paths = std::move(binding_paths); + rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)}); (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt)); remap_vars_indices.clear(); } @@ -198,11 +202,13 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // Array kwargs_values; if (!realize) { kwargs_keys.push_back("no_realize"); - kwargs_values.push_back(LiteralDoc::Boolean(true)); + kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); } - return ScopeDoc( - NullOpt, TIR("block")->Call({LiteralDoc::Str(block->name_hint)}, kwargs_keys, kwargs_values), - (*frame)->stmts); + return ScopeDoc(NullOpt, + TIR("block") // + ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))}, + kwargs_keys, kwargs_values), + (*frame)->stmts); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 126a6e58273f8..b947039b58de9 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -56,7 +56,7 @@ Map BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape"); // Step 2. Handle `buffer.dtype` if (buffer->dtype != Default::BufferDType()) { - kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype)); + kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); } // Step 3. Handle `buffer.data` implicit_var_def(buffer->data, p->Attr("data"), "data"); @@ -78,20 +78,22 @@ Map BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, { String scope = buffer.scope(); if (scope != "global") { - kwargs.Set("scope", LiteralDoc::Str(scope)); + kwargs.Set( + "scope", + LiteralDoc::Str(scope, p->Attr("data")->Attr("type_annotation")->Attr("storage_scope"))); } } // Step 7. Handle `buffer.data_alignment` if (buffer->data_alignment != runtime::kAllocAlignment) { - kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment)); + kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, p->Attr("data_alignment"))); } // Step 8. Handle `buffer.offset_factor` if (needs_print_factor || buffer->offset_factor != 1) { - kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor)); + kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor, p->Attr("offset_factor"))); } // Step 9. Handle `buffer.buffer_type` if (buffer->buffer_type != tir::BufferType::kDefault) { - kwargs.Set("type", LiteralDoc::Str("auto")); + kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type"))); } // Step 10. Handle `buffer.axis_separator` if (!buffer->axis_separators.empty()) { @@ -130,7 +132,8 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& const IRDocsifier& d) { Map attrs = BufferAttrs(buffer, p, frame, d); ExprDoc shape = attrs.Get("shape").value(); - ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype)); + ExprDoc dtype = + attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); return TIR("Buffer")->Call({shape, dtype}, {}, {}); } diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 1f2ba97700cbb..6e0cfd420262a 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -24,17 +24,17 @@ namespace tvm { namespace script { namespace printer { -Doc PrintVar(const tir::Var& var, const ObjectPath& p, const IRDocsifier& d) { +Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { if (Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); Type type = var->type_annotation; if (const auto* ptr_type = type.as()) { ICHECK(ptr_type->element_type->IsInstance()); - ExprDoc rhs = d->AsDoc(type, p->Attr("type_annotation")); + ExprDoc rhs = d->AsDoc(type, var_p->Attr("type_annotation")); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } else { - ExprDoc rhs = TIR("var")->Call({LiteralDoc::DataType(var->dtype)}); + ExprDoc rhs = TIR("var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } } @@ -56,13 +56,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::IterVar var, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::IterVar var, ObjectPath var_p, IRDocsifier d) -> Doc { return TIR("iter_var") ->Call({ - d->AsDoc(var->var, p->Attr("var")), - d->AsDoc(var->dom, p->Attr("dom")), - LiteralDoc::Str(IterVarType2String(var->iter_type)), - LiteralDoc::Str(var->thread_tag), + d->AsDoc(var->var, var_p->Attr("var")), + d->AsDoc(var->dom, var_p->Attr("dom")), + LiteralDoc::Str(IterVarType2String(var->iter_type), var_p->Attr("iter_type")), + LiteralDoc::Str(var->thread_tag, var_p->Attr("thread_tag")), }); }); @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc { - ExprDoc dtype = LiteralDoc::DataType(cast->dtype); + ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); return TIR("Cast")->Call({dtype, value}); }); @@ -97,20 +97,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Ramp ramp, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Ramp ramp, ObjectPath ramp_p, IRDocsifier d) -> Doc { return TIR("Ramp")->Call({ - d->AsDoc(ramp->base, p->Attr("base")), - d->AsDoc(ramp->stride, p->Attr("stride")), - LiteralDoc::Int(ramp->lanes), + d->AsDoc(ramp->base, ramp_p->Attr("base")), + d->AsDoc(ramp->stride, ramp_p->Attr("stride")), + LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")), }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Broadcast bc, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Broadcast bc, ObjectPath bc_p, IRDocsifier d) -> Doc { return TIR("Broadcast") ->Call({ - d->AsDoc(bc->value, p->Attr("value")), - LiteralDoc::Int(bc->lanes), + d->AsDoc(bc->value, bc_p->Attr("value")), + LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")), }); }); @@ -165,7 +165,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Call call, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc { static const OpAttrMap& op_names = Op::GetAttrMap("TScriptPrinterName"); static const std::unordered_set dtype_first_arg = { @@ -196,7 +196,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } prefix = TIR(name); } else if (const auto* gv = call->op.as()) { - prefix = LiteralDoc::Str(gv->name_hint); + prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); } else { LOG(FATAL) << "call: " << call; } @@ -204,13 +204,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) int n_args = call->args.size(); args.reserve(n_args + 1); if (dtype_first_arg.count(call->op.get())) { - args.push_back(LiteralDoc::DataType(call->dtype)); + args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } for (int i = 0; i < n_args; ++i) { - args.push_back(d->AsDoc(call->args[i], p->Attr("args")->ArrayIndex(i))); + args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); } if (dtype_last_arg.count(call->op.get())) { - args.push_back(LiteralDoc::DataType(call->dtype)); + args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } return prefix->Call(args); }); @@ -227,7 +227,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc init = d->AsDoc(r->init, p->Attr("init")); ExprDoc axis = d->AsDoc(r->axis, p->Attr("axis")); ExprDoc condition = d->AsDoc(r->condition, p->Attr("condition")); - ExprDoc value_index = LiteralDoc::Int(r->value_index); + ExprDoc value_index = LiteralDoc::Int(r->value_index, p->Attr("value_index")); return TIR("reduce")->Call({combiner}, {"source", "init", "axis", "condition", "value_index"}, {source, init, axis, condition, value_index}); LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r; diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index c8e2580f9c6fa..2a81c37061c69 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -23,7 +23,7 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::For loop, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::For loop, ObjectPath loop_p, IRDocsifier d) -> Doc { // Step 1. Check syntactic sugar: `T.grid` std::vector grid; std::unordered_set grid_loop_vars; @@ -55,10 +55,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n; ++i) { const tir::ForNode* loop = grid[i]; lhs.push_back(DefineVar(loop->loop_var, *f, d)); - rhs.push_back(d->AsDoc(loop->extent, p->Attr("extent"))); - p = p->Attr("body"); + rhs.push_back(d->AsDoc(loop->extent, loop_p->Attr("extent"))); + loop_p = loop_p->Attr("body"); } - AsDocBody(grid.back()->body, p, (*f).get(), d); + AsDocBody(grid.back()->body, loop_p, (*f).get(), d); return ForDoc(TupleDoc(lhs), TIR("grid")->Call(rhs), (*f)->stmts); } // Step 3. If not `T.grid`, print loop kind accordingly @@ -68,13 +68,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Optional annotations = NullOpt; Optional thread = NullOpt; if (tir::is_zero(loop->min)) { - max = d->AsDoc(loop->extent, p->Attr("extent")); + max = d->AsDoc(loop->extent, loop_p->Attr("extent")); } else { - min = d->AsDoc(loop->min, p->Attr("min")); - max = d->AsDoc(loop->min + loop->extent, p->Attr("extent")); + min = d->AsDoc(loop->min, loop_p->Attr("min")); + max = d->AsDoc(loop->min + loop->extent, loop_p->Attr("extent")); } if (!loop->annotations.empty()) { - annotations = d->AsDoc(loop->annotations, p->Attr("annotations")); + annotations = d->AsDoc(loop->annotations, loop_p->Attr("annotations")); } ExprDoc prefix{nullptr}; if (loop->kind == tir::ForKind::kSerial) { @@ -91,7 +91,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) prefix = TIR("vectorized"); } else if (loop->kind == tir::ForKind::kThreadBinding) { prefix = TIR("thread_binding"); - thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag); + thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag, + loop_p->Attr("thread_binding")); } else { LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind); } @@ -113,7 +114,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_values.push_back(annotations.value()); } ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values); - AsDocBody(loop->body, p, (*f).get(), d); + AsDocBody(loop->body, loop_p, (*f).get(), d); return ForDoc(lhs, rhs, (*f)->stmts); }); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index ad00c42119f61..1214f822610cc 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -27,26 +27,26 @@ namespace printer { TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IntImm imm, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IntImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == Default::IntDType()) { - return LiteralDoc::Int(imm->value); + return LiteralDoc::Int(imm->value, imm_p->Attr("value")); } else if (dtype == DataType::Bool()) { - return LiteralDoc::Boolean(imm->value); + return LiteralDoc::Boolean(imm->value, imm_p->Attr("value")); } else { return TIR(runtime::DLDataType2String(dtype)) // - ->Call({LiteralDoc::Int(imm->value)}); + ->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value"))}); } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FloatImm imm, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FloatImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == Default::FloatDType()) { - return LiteralDoc::Float(imm->value); + return LiteralDoc::Float(imm->value, imm_p->Attr("value")); } else { return TIR(runtime::DLDataType2String(dtype)) // - ->Call({LiteralDoc::Float(imm->value)}); + ->Call({LiteralDoc::Float(imm->value, imm_p->Attr("value"))}); } }); @@ -65,26 +65,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PointerType ty, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PointerType ty, ObjectPath ty_p, IRDocsifier d) -> Doc { ExprDoc element_type{nullptr}; if (const auto* prim_type = ty->element_type.as()) { - std::string dtype = - prim_type->dtype.is_void() ? "void" : runtime::DLDataType2String(prim_type->dtype); - element_type = LiteralDoc::Str(dtype); + element_type = LiteralDoc::DataType(prim_type->dtype, // + ty_p->Attr("element_type")->Attr("dtype")); } else { - element_type = d->AsDoc(ty->element_type, p->Attr("element_type")); + element_type = d->AsDoc(ty->element_type, ty_p->Attr("element_type")); } if (ty->storage_scope == "") { return TIR("Ptr")->Call({element_type}); } else { - return TIR("Ptr")->Call({element_type, LiteralDoc::Str(ty->storage_scope)}); + return TIR("Ptr")->Call( + {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](TupleType ty, ObjectPath p, IRDocsifier d) -> Doc { if (ty->fields.empty()) { - return LiteralDoc::None(); + return LiteralDoc::None(p); } return TIR("Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 57b4c695a4eee..7c8d44c10e722 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -173,31 +173,35 @@ bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::Allocate stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); OccurrenceCounter counter(stmt->buffer_var.get()); counter(stmt->body); if (counter.count == 1 && IsAllocateDeclBufferPattern(stmt.get())) { - return d->AsDoc(stmt->body, p->Attr("body")); + return d->AsDoc(stmt->body, stmt_p->Attr("body")); } - String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); Array args; Array kwargs_keys; Array kwargs_values; - args.push_back(d->AsDoc(stmt->extents, p->Attr("extents"))); - args.push_back(LiteralDoc::DataType(stmt->dtype)); - args.push_back(LiteralDoc::Str(storage_scope)); + args.push_back(d->AsDoc(stmt->extents, stmt_p->Attr("extents"))); + args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); + args.push_back(LiteralDoc::Str(tir::GetPtrStorageScope(stmt->buffer_var), + stmt_p + ->Attr("buffer_var") // + ->Attr("type_annotation") + ->Attr("storage_scope"))); if (!tir::is_one(stmt->condition)) { - args.push_back(d->AsDoc(stmt->condition, p->Attr("condition"))); + args.push_back(d->AsDoc(stmt->condition, stmt_p->Attr("condition"))); } if (!stmt->annotations.empty()) { kwargs_keys.push_back("annotations"); - kwargs_values.push_back(d->AsDoc(stmt->annotations, p->Attr("annotations"))); + kwargs_values.push_back( + d->AsDoc(stmt->annotations, stmt_p->Attr("annotations"))); } ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d); With f(d, stmt); ExprDoc rhs = TIR("allocate")->Call(args, kwargs_keys, kwargs_values); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); @@ -215,9 +219,9 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { runtime::DataType dtype = arr.DataType(); for (int i = 0; i < tot_dim; i++) { if (dtype.is_float()) { - result.push_back(LiteralDoc::Float(data_ptr[i])); + result.push_back(LiteralDoc::Float(data_ptr[i], NullOpt)); } else { - result.push_back(LiteralDoc::Int(data_ptr[i])); + result.push_back(LiteralDoc::Int(data_ptr[i], NullOpt)); } if (i == NUM_PRINT) { break; @@ -228,7 +232,7 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tir::AllocateConst stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); Array args; @@ -273,12 +277,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) LOG(FATAL) << "DataType not supported"; } args.push_back(data_doc); - args.push_back(LiteralDoc::DataType(stmt->dtype)); - args.push_back(d->AsDoc(stmt->extents, p->Attr("extents"))); + args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); + args.push_back(d->AsDoc(stmt->extents, stmt_p->Attr("extents"))); ExprDoc rhs = TIR("allocate_const")->Call(args, kwargs_keys, kwargs_values); With f(d, stmt); ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); @@ -323,18 +327,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::AttrStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); Optional rhs = NullOpt; tir::Stmt body = stmt->body; - ObjectPath body_p = p->Attr("body"); + ObjectPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { if (const auto* realize = stmt->body.as()) { if (realize->buffer.same_as(stmt->node)) { - rhs = - DocsifyBufferRealize(realize, - /*value=*/d->AsDoc(stmt->value, p->Attr("value")), - /*p=*/p->Attr("body"), d); + rhs = DocsifyBufferRealize( + realize, + /*value=*/d->AsDoc(stmt->value, stmt_p->Attr("value")), + /*p=*/stmt_p->Attr("body"), d); body = realize->body; body_p = body_p->Attr("body"); } @@ -344,25 +348,28 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (const auto* iter_var = stmt->node.as()) { if (!d->IsVarDefined(iter_var->var)) { // `DefineVar` is not used here because a more specific name is desirable + ObjectPath iter_var_p = stmt_p->Attr("node"); Frame f = FindLowestVarDef(iter_var->var, d).value(); DefineVar(iter_var->var, f, d); f->stmts.push_back( - AssignDoc(d->AsDoc(iter_var->var, p->Attr("node")->Attr("var")), - TIR("env_thread")->Call({LiteralDoc::Str(iter_var->thread_tag)}), // + AssignDoc(d->AsDoc(iter_var->var, iter_var_p->Attr("var")), + TIR("env_thread") + ->Call({LiteralDoc::Str(iter_var->thread_tag, + iter_var_p->Attr("thread_tag"))}), // NullOpt)); } rhs = TIR("launch_thread") ->Call({ - d->AsDoc(iter_var->var, p->Attr("node")), - d->AsDoc(stmt->value, p->Attr("value")), + d->AsDoc(iter_var->var, stmt_p->Attr("node")), + d->AsDoc(stmt->value, stmt_p->Attr("value")), }); } } if (!rhs.defined()) { rhs = TIR("attr")->Call({ - d->AsDoc(stmt->node, p->Attr("node")), - LiteralDoc::Str(stmt->attr_key), - d->AsDoc(stmt->value, p->Attr("value")), + d->AsDoc(stmt->node, stmt_p->Attr("node")), + LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")), + d->AsDoc(stmt->value, stmt_p->Attr("value")), }); } With f(d, stmt); diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index de9da80140e4a..86ce4e21351fc 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -25,7 +25,6 @@ #include "control_flow_graph.h" #include -#include #include #include #include