Skip to content

Commit

Permalink
Better host handling in CompilationConfig etc
Browse files Browse the repository at this point in the history
(This is in preparation for apache#9326, which I'm trying to
make as small as possible, sorry for the scatter gun.)

If no explicit host target is given but the given
TargetMap has targets with hosts, try to use those
to establish the host_target.

Also make sure both the 'legacy' TargetMap representation
and the newer representation agree to pointer equality on
their targets.

That triggered a small change in the Interpreter to
make better use of the CompilationConfig.

Since Targets are used in ObjectPtrEquality maps AND we
tend to call CheckAndUpdateHostConsistency all over the
place (I count 65) I had a tricky time debugging failures. Added
a ToDebugString() to Target which will include the host,
and made sure the pretty printer will use the debug-friendly
form when the show_meta_data_ flag is false.
  • Loading branch information
mbs-octoml committed Nov 5, 2021
1 parent 048994b commit 63f1375
Show file tree
Hide file tree
Showing 9 changed files with 295 additions and 133 deletions.
9 changes: 9 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ class TargetNode : public Object {
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;

/*!
* \brief Returns a human readable representation of \p Target which includes all fields,
* especially the host. Useful for diagnostic messages and debugging.
*
* TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently
* code depends on str() and << being the same.
*/
String ToDebugString() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("tag", &tag);
Expand Down
3 changes: 2 additions & 1 deletion src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1955,7 +1955,8 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
return CreateModulePass(
[](const IRModule& mod, const PassContext& ctx) {
auto text = AsText(mod, true);
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
},
0, "AnnotateSpans", {});
Expand Down
139 changes: 83 additions & 56 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/target/se_scope.h>
#include <tvm/tir/function.h>

#include "../ir/attr_functor.h"
Expand Down Expand Up @@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<IRModule>(node));
} else if (!show_meta_data_ && node.as<BaseAttrsNode>()) {
// Show attributes in readable form.
return PrintAttrs(Downcast<Attrs>(node));
} else {
// default module.
std::ostringstream os;
Expand Down Expand Up @@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
params.push_back(d);
}
doc << Doc::Concat(params) << ") ";
Expand Down Expand Up @@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
Doc doc;
doc << "Tensor[(";
std::vector<Doc> shapes;
for (ObjectRef shape : node->shape) {
shapes.push_back(PrintAttr(shape));
for (const PrimExpr& prim_expr : node->shape) {
// Though not bound within an attribute the attribute visitor will handle the PrimExprs we
// care about.
shapes.push_back(PrintAttributeValue(prim_expr));
}
doc << Doc::Concat(shapes);
return doc << "), " << PrintDType(node->dtype) << "]";
Expand Down Expand Up @@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
// Overload of Attr printing functions
//------------------------------------

Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
if (value.defined()) {
Doc printed_attr;
if (value.as<tvm::tir::AnyNode>()) {
printed_attr << "?";
} else if (auto str_obj = value.as<tvm::StringObj>()) {
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else {
printed_attr = VisitAttr(value);
}
return printed_attr;
} else {
return Doc::Text("None");
}
}

Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
return PrintAttr(GetRef<ObjectRef>(op), /*meta=*/true);
// Since we don't have any overload for a specific attribute type we'll need to force
// the meta[...] representation to avoid infinite regress.
return PrintAttributeValue(GetRef<ObjectRef>(op), /*force_meta=*/true);
}

Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
Doc doc;
doc << "[";
std::vector<Doc> arr_vals;
for (auto val : *op) {
arr_vals.push_back(PrintAttr(val));
for (const auto& val : *op) {
arr_vals.push_back(PrintAttributeValue(val));
}
doc << Doc::Concat(arr_vals);
doc << "]";
Expand Down Expand Up @@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
doc << key << "=" << *value << "f";
docs->push_back(doc);
}

void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); }
void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); }
void Visit(const char* key, int* value) final { PrintKV(key, *value); }
Expand All @@ -844,60 +829,102 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::ObjectRef* obj) final {
PrintKV(key, parent_->PrintAttr(*obj));
PrintKV(key, parent_->PrintAttributeValue(*obj));
}

private:
std::vector<Doc>* docs;
RelayTextPrinter* parent_;
};

Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) {
std::vector<Doc> docs;
AttrPrinter printer(&docs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
Doc doc;
doc << "{" << Doc::Concat(docs) << "}";

return doc;
void RelayTextPrinter::AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs,
bool include_type_key) {
if (!attrs.defined()) {
return;
}
AttrPrinter printer(docs, this);
// Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this
// case we are read-only.
const_cast<BaseAttrsNode*>(attrs.get())->VisitNonDefaultAttrs(&printer);
if (include_type_key) {
std::string s = attrs->GetTypeKey();
printer.Visit("attrs_type_key", &s);
}
}

std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
if (!attrs.defined()) {
return docs;
}
const auto* op_node = op.as<OpNode>();
if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) {
// fallback
Doc doc;
doc << meta_->GetMetaNode(attrs);
docs.push_back(doc);
return docs;
// The parser can only understand calls with attributes if they match the operator's
// declared attribute type. If that's not the case fall back to the meta[...] representation.
docs.push_back(meta_->GetMetaNode(attrs));
} else {
// Show attributes in readable form.
AttrPrinter printer(&docs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
if (!op_node) {
// print call attr type key to restore expr for relay parser
std::string s = std::string(attrs->GetTypeKey());
printer.Visit("attrs_type_key", &s);
}
return docs;
AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node);
}
return docs;
}

std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* dict_attrs = attrs.as<DictAttrsNode>();
ICHECK(dict_attrs);
if (!dict_attrs.defined()) {
return docs;
}
for (const auto& k : dict_attrs->dict) {
Doc doc;
doc << k.first << "=" << Print(k.second);
doc << k.first << "=" << PrintAttributeValue(k.second);
docs.push_back(doc);
}
return docs;
}

Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) {
if (value.defined()) {
Doc printed_attr;
if (value.as<tvm::tir::AnyNode>()) {
printed_attr << "?";
} else if (auto str_obj = value.as<tvm::StringObj>()) {
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (force_meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else if (const auto* se_scope_node = value.as<SEScopeNode>()) {
if (show_meta_data_) {
printed_attr = meta_->GetMetaNode(GetRef<SEScope>(se_scope_node));
} else {
// Special case: The ReprPrinter for SEScopeNodes is much easier to work with while
// debugging.
std::ostringstream os;
os << GetRef<SEScope>(se_scope_node);
return Doc::Text(os.str());
}
} else if (const auto* base_attr_node = value.as<BaseAttrsNode>()) {
if (show_meta_data_) {
printed_attr = meta_->GetMetaNode(GetRef<Attrs>(base_attr_node));
} else {
// Special case: The non-meta form for attributes are much easier to work with while
// debugging.
printed_attr = PrintAttrsAsAttributeValue(GetRef<Attrs>(base_attr_node));
}
} else {
printed_attr = VisitAttr(value);
}
return printed_attr;
} else {
return Doc::Text("None");
}
}

Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) {
std::vector<Doc> docs;
AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false);
Doc doc;
doc << "{" << Doc::Concat(docs) << "}";
return doc;
}

Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
const auto* span_node = span.as<SpanNode>();
Expand Down
32 changes: 29 additions & 3 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,36 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
// numbers to be reused and prevents hoisted vars from escaping too far
Doc PrintScope(const ObjectRef& node);
Doc PrintFinal(const ObjectRef& node);
Doc PrintAttrs(const Attrs& attrs);

/*!
* \brief Returns \p attrs printed using the generic attribute visitor, as a sequence
* of key=value entries, if any.
*/
void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key);

/*!
* \brief Returns \p attrs printed as a sequence of key=value entries, if any.
* This is used for call attributes.
*/
std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);

/*!
* \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any.
* This is used for function definition attributes.
*/
std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs);

/*!
* \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta
* is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag.
*/
Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false);

/*!
* \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces.
*/
Doc PrintAttrsAsAttributeValue(const Attrs& attrs);

Doc PrintSpan(const Span& span);

Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
Expand Down Expand Up @@ -162,7 +189,6 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
Doc PrintAttr(const ObjectRef& value, bool meta = false);
Doc VisitAttrDefault_(const Object* op) final;
Doc VisitAttr_(const ArrayNode* op) final;
Doc VisitAttr_(const tir::IntImmNode* op) final;
Expand Down
Loading

0 comments on commit 63f1375

Please sign in to comment.