Skip to content

Commit

Permalink
[TVMScript] Add ObjectPath to LiteralDoc
Browse files Browse the repository at this point in the history
This PR adds ObjectPath to LiteralDoc to allow integer/float/string/...
literals to have their own object path. This is a final preparation
towards structural error rendering when SEqual fails.
  • Loading branch information
junrushao committed Jan 23, 2023
1 parent cc7def0 commit 9ba9450
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 134 deletions.
26 changes: 17 additions & 9 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,40 +243,48 @@ class LiteralDocNode : public ExprDocNode {
*/
class LiteralDoc : public ExprDoc {
protected:
explicit LiteralDoc(ObjectRef value);
LiteralDoc(ObjectRef value, ObjectPath object_path);
explicit LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path);

public:
/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
*/
static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }
static LiteralDoc None(const Optional<ObjectPath>& p) {
return LiteralDoc(ObjectRef(nullptr), p);
}
/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*/
static LiteralDoc Int(int64_t v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
static LiteralDoc Int(int64_t v, const Optional<ObjectPath>& p) {
return LiteralDoc(IntImm(DataType::Int(64), v), p);
}
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*/
static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }
static LiteralDoc Boolean(bool v, const Optional<ObjectPath>& p) {
return LiteralDoc(IntImm(DataType::Bool(), v), p);
}
/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*/
static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }
static LiteralDoc Float(double v, const Optional<ObjectPath>& p) {
return LiteralDoc(FloatImm(DataType::Float(64), v), p);
}
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*/
static LiteralDoc Str(const String& v) { return LiteralDoc(v); }
static LiteralDoc Str(const String& v, const Optional<ObjectPath>& p) { return LiteralDoc(v, p); }
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*/
static LiteralDoc DataType(const DLDataType& v) {
return LiteralDoc::Str(runtime::DLDataType2String(v));
static LiteralDoc DataType(const runtime::DataType& v, const Optional<ObjectPath>& p) {
std::string dtype = v.is_void() ? "void" : runtime::DLDataType2String(v);
return LiteralDoc::Str(dtype, p);
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,12 @@ inline void FrameNode::ExitWithScope() {

template <class TDoc>
inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const {
if (!obj.defined()) {
return Downcast<TDoc>(LiteralDoc::None());
if (obj.defined()) {
Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this));
d->source_paths.push_back(path);
return Downcast<TDoc>(d);
}
return Downcast<TDoc>(
IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this)));
return Downcast<TDoc>(LiteralDoc::None(path));
}

inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) {
Expand Down
32 changes: 26 additions & 6 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,37 @@ class LiteralDoc(ExprDoc):

value: Union[str, IntImm, FloatImm, None]

def __init__(self, value: Union[str, float, bool, int, None]):
def __init__(
self,
value: Union[str, float, bool, int, None],
path: Optional[ObjectPath] = None,
):
if value is None:
self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone, path) # type: ignore # pylint: disable=no-member
elif isinstance(value, str):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocStr, # type: ignore # pylint: disable=no-member
value,
path,
)
elif isinstance(value, float):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocFloat, # type: ignore # pylint: disable=no-member
value,
path,
)
elif isinstance(value, bool):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocBoolean, # type: ignore # pylint: disable=no-member
value,
path,
)
elif isinstance(value, int):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocInt, # type: ignore # pylint: disable=no-member
value,
path,
)
else:
raise TypeError(f"Unsupported type {type(value)} for LiteralDoc")

Expand Down
26 changes: 9 additions & 17 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,12 @@ StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
this->data_ = std::move(n);
}

LiteralDoc::LiteralDoc(ObjectRef value) {
LiteralDoc::LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
this->data_ = std::move(n);
}

LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
n->source_paths.push_back(object_path);
if (object_path.defined()) {
n->source_paths.push_back(object_path.value());
}
this->data_ = std::move(n);
}

Expand Down Expand Up @@ -250,15 +246,11 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
});

TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
.set_body_typed<LiteralDoc(int64_t)>(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
.set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
.set_body_typed<LiteralDoc(double)>(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr")
.set_body_typed<LiteralDoc(const String&)>(LiteralDoc::Str);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);

TVM_REGISTER_NODE_TYPE(IdDocNode);
TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
Expand Down
18 changes: 9 additions & 9 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<GlobalVar>("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc {
return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)});
return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
return IR("Op")->Call({LiteralDoc::Str(op->name)});
return IR("Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<TypeVar>("", [](TypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc {
return IR("TypeVar")->Call({LiteralDoc::Str(type_var->name_hint), //
LiteralDoc::Str(TypeKind2String(type_var->kind))});
.set_dispatch<TypeVar>("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc {
return IR("TypeVar")->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), //
LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<GlobalTypeVar>( //
"", [](GlobalTypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc {
"", [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc {
return IR("GlobalTypeVar")
->Call({LiteralDoc::Str(type_var->name_hint), //
LiteralDoc::Str(TypeKind2String(type_var->kind))});
->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")),
LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand All @@ -94,7 +94,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<TensorType>("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc {
return IR("TensorType")
->Call({d->AsDoc<ExprDoc>(type->shape, p->Attr("shape")),
LiteralDoc::DataType(type->dtype)});
LiteralDoc::DataType(type->dtype, p->Attr("dtype"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/ir/misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace printer {

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<String>("", [](String s, ObjectPath p, IRDocsifier d) -> Doc {
return LiteralDoc::Str(s);
return LiteralDoc::Str(s, p);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
1 change: 0 additions & 1 deletion src/script/printer/legacy_repr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)

TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
// TODO(tvm-team) redirect to Text printer once we have a good text format.
auto* node = static_cast<const PrimFuncNode*>(ref.get());
(*p) << "PrimFunc(" << node->params << ") ";
if (node->attrs.defined()) {
Expand Down
18 changes: 12 additions & 6 deletions src/script/printer/tir/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,20 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
lhs.reserve(m);
loop_var_doc.reserve(m);
std::string binding_type = "";
Array<ObjectPath> binding_paths;
for (int i : remap_vars_indices) {
tir::IterVar iter_var = block->iter_vars[i];
ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i);
lhs.push_back(DefineVar(iter_var->var, *frame, d));
loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i],
realize_p->Attr("iter_values")->ArrayIndex(i)));
binding_paths.push_back(iter_var_p->Attr("iter_type"));
binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R";
}
ExprDoc rhs = TIR("axis")->Attr("remap");
rhs = rhs->Call({LiteralDoc::Str(binding_type), ListDoc(loop_var_doc)});
ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt);
binding_str->source_paths = std::move(binding_paths);
rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)});
(*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt));
remap_vars_indices.clear();
}
Expand Down Expand Up @@ -198,11 +202,13 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
Array<ExprDoc> kwargs_values;
if (!realize) {
kwargs_keys.push_back("no_realize");
kwargs_values.push_back(LiteralDoc::Boolean(true));
kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt));
}
return ScopeDoc(
NullOpt, TIR("block")->Call({LiteralDoc::Str(block->name_hint)}, kwargs_keys, kwargs_values),
(*frame)->stmts);
return ScopeDoc(NullOpt,
TIR("block") //
->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))},
kwargs_keys, kwargs_values),
(*frame)->stmts);
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
15 changes: 9 additions & 6 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p,
array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape");
// Step 2. Handle `buffer.dtype`
if (buffer->dtype != Default::BufferDType()) {
kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype));
kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
}
// Step 3. Handle `buffer.data`
implicit_var_def(buffer->data, p->Attr("data"), "data");
Expand All @@ -78,20 +78,22 @@ Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p,
{
String scope = buffer.scope();
if (scope != "global") {
kwargs.Set("scope", LiteralDoc::Str(scope));
kwargs.Set(
"scope",
LiteralDoc::Str(scope, p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
}
}
// Step 7. Handle `buffer.data_alignment`
if (buffer->data_alignment != runtime::kAllocAlignment) {
kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment));
kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, p->Attr("data_alignment")));
}
// Step 8. Handle `buffer.offset_factor`
if (needs_print_factor || buffer->offset_factor != 1) {
kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor));
kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor, p->Attr("offset_factor")));
}
// Step 9. Handle `buffer.buffer_type`
if (buffer->buffer_type != tir::BufferType::kDefault) {
kwargs.Set("type", LiteralDoc::Str("auto"));
kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type")));
}
// Step 10. Handle `buffer.axis_separator`
if (!buffer->axis_separators.empty()) {
Expand Down Expand Up @@ -130,7 +132,8 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame&
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
ExprDoc dtype =
attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Expand Down
Loading

0 comments on commit 9ba9450

Please sign in to comment.