From b3e6636b430991447433456bd607ca406949eccf Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 9 Nov 2021 08:44:21 -0800 Subject: [PATCH] Better host handling in CompilationConfig & debug printing (#9460) (This is a bit of a grab bag in preparation for #9326 which I'm trying to minimize) While switching the device planner to use SEScopes I had a lot of trouble with Target's not matching up. - If no explicit host target is given but the given TargetMap has targets with hosts, try to use those to establish the host_target. - Make sure both the 'legacy' TargetMap representation and the newer representation agree to pointer equality on their targets. - Make sure the Interpreter uses the target from CompilationConfig since it's been normalized. To debug the above: - When in pretty printing with show_meta_data_ false give as much detail on SEScopes, Targets and call attributes as possible. That needed some rework in the relay_text_printer.cc. - Ditto for critical 'target' attribute on PrimFuncs. - Also added a Target::ToDebugString so I could see the host fields along with everything else since a lot of problems were caused by a mismatch of 'the same' Target with and without a host. (Tried using that for the ReprPrinter but broken unit tests.) Note that the codebase assumes Targets are compared by ObjectPtrEquality, yet CheckAndUpdateHostConsistency (I count 65 call sites) changes the targets. Ultimately CompilationConfig or it's ultimate replacement should ensure we munge targets only once at the 'main' entry points. --- include/tvm/target/target.h | 9 + src/parser/parser.cc | 3 +- src/printer/relay_text_printer.cc | 173 +++++++++++++------- src/printer/text_printer.cc | 1 + src/printer/text_printer.h | 38 ++++- src/printer/tir_text_printer.cc | 3 + src/relay/backend/interpreter.cc | 78 +++++---- src/target/compilation_config.cc | 76 +++++---- src/target/se_scope.cc | 10 +- src/target/target.cc | 43 +++++ tests/cpp/target/compilation_config_test.cc | 38 ++++- tests/python/relay/test_ir_text_printer.py | 2 +- 12 files changed, 340 insertions(+), 134 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index e0d34c87dda7..21760bdc8dbf 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -66,6 +66,15 @@ class TargetNode : public Object { /*! \return The Optional typed target host of the TargetNode */ TVM_DLL Optional GetHost() const; + /*! + * \brief Returns a human readable representation of \p Target which includes all fields, + * especially the host. Useful for diagnostic messages and debugging. + * + * TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently + * code depends on str() and << being the same. + */ + String ToDebugString() const; + void VisitAttrs(AttrVisitor* v) { v->Visit("kind", &kind); v->Visit("tag", &tag); diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 486799603354..092d5b61eeec 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1955,7 +1955,8 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr") TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() { return CreateModulePass( [](const IRModule& mod, const PassContext& ctx) { - auto text = AsText(mod, true); + String text = AsText(mod, /*show_meta_data=*/true); + VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; return ParseModule("GeneratedSource", text); }, 0, "AnnotateSpans", {}); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 9eca038e5c93..7454cfdf336e 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -37,6 +37,7 @@ #include #include #include +#include #include #include "../ir/attr_functor.h" @@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { return PrintPattern(Downcast(node), meta); } else if (node.as()) { return PrintMod(Downcast(node)); - } else if (!show_meta_data_ && node.as()) { - // Show attributes in readable form. - return PrintAttrs(Downcast(node)); } else { // default module. std::ostringstream os; @@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { for (Var param : fn->params) { params.push_back(AllocVar(param)); } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + for (const Doc& d : PrintDictAttrs(fn->attrs)) { params.push_back(d); } doc << Doc::Concat(params) << ") "; @@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { Doc doc; doc << "Tensor[("; std::vector shapes; - for (ObjectRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); + for (const PrimExpr& prim_expr : node->shape) { + // Though not bound within an attribute the attribute visitor will handle the PrimExprs we + // care about. + shapes.push_back(PrintAttributeValue(prim_expr)); } doc << Doc::Concat(shapes); return doc << "), " << PrintDType(node->dtype) << "]"; @@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) { // Overload of Attr printing functions //------------------------------------ -Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { - if (value.defined()) { - Doc printed_attr; - if (value.as()) { - printed_attr << "?"; - } else if (auto str_obj = value.as()) { - printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (meta) { - printed_attr = meta_->GetMetaNode(Downcast(value)); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; - } else { - return Doc::Text("None"); - } -} - Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), /*meta=*/true); + // Since we don't have any overload for a specific attribute type we'll need to force + // the meta[...] representation to avoid infinite regress. + return PrintAttributeValue(GetRef(op), /*force_meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { Doc doc; doc << "["; std::vector arr_vals; - for (auto val : *op) { - arr_vals.push_back(PrintAttr(val)); + for (const auto& val : *op) { + arr_vals.push_back(PrintAttributeValue(val)); } doc << Doc::Concat(arr_vals); doc << "]"; @@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { doc << key << "=" << *value << "f"; docs->push_back(doc); } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } void Visit(const char* key, int* value) final { PrintKV(key, *value); } @@ -844,7 +829,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { LOG(FATAL) << "do not allow NDarray as argument"; } void Visit(const char* key, runtime::ObjectRef* obj) final { - PrintKV(key, parent_->PrintAttr(*obj)); + PrintKV(key, parent_->PrintAttributeValue(*obj)); } private: @@ -852,50 +837,126 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { RelayTextPrinter* parent_; }; -Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) { - std::vector docs; - AttrPrinter printer(&docs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - Doc doc; - doc << "{" << Doc::Concat(docs) << "}"; - - return doc; +void RelayTextPrinter::AppendGenericAttrs(std::vector* docs, const Attrs& attrs, + bool include_type_key) { + if (!attrs.defined()) { + return; + } + AttrPrinter printer(docs, this); + // Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this + // case we are read-only. + const_cast(attrs.get())->VisitNonDefaultAttrs(&printer); + if (include_type_key) { + std::string s = attrs->GetTypeKey(); + printer.Visit("attrs_type_key", &s); + } } std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; - if (!attrs.defined()) return docs; + if (!attrs.defined()) { + return docs; + } const auto* op_node = op.as(); if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) { - // fallback + // The parser can only understand calls with attributes if they match the operator's + // declared attribute type. If that's not the case fall back to the meta[...] representation. + docs.push_back(meta_->GetMetaNode(attrs)); + } else { + AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node); + } + return docs; +} + +std::vector RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) { + if (!dict_attrs.defined()) { + return {}; + } + return PrintDictAttrs(dict_attrs->dict); +} + +std::vector RelayTextPrinter::PrintDictAttrs(const Map& dict_attrs) { + std::vector docs; + if (!dict_attrs.defined()) { + return docs; + } + for (const auto& k : dict_attrs) { Doc doc; - doc << meta_->GetMetaNode(attrs); + doc << k.first << "=" << PrintAttributeValue(k.second); docs.push_back(doc); - return docs; - } else { - // Show attributes in readable form. - AttrPrinter printer(&docs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - if (!op_node) { - // print call attr type key to restore expr for relay parser - std::string s = std::string(attrs->GetTypeKey()); - printer.Visit("attrs_type_key", &s); + } + return docs; +} + +Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) { + if (value.defined()) { + Doc printed_attr; + if (value.as()) { + printed_attr << "?"; + } else if (auto str_obj = value.as()) { + printed_attr << Doc::StrLiteral(GetRef(str_obj)); + } else if (force_meta) { + printed_attr = meta_->GetMetaNode(Downcast(value)); + } else if (const auto* se_scope_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(se_scope_node)); + } else { + // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while + // debugging. + std::ostringstream os; + os << GetRef(se_scope_node); + return Doc::Text(os.str()); + } + } else if (const auto* base_attr_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(base_attr_node)); + } else { + // Special case: The non-meta form for attributes are much easier to work with while + // debugging. + printed_attr = PrintAttrsAsAttributeValue(GetRef(base_attr_node)); + } + } else if (const auto* base_map_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(base_map_node)); + } else { + // Special case: Show maps fields as key=value pairs to help debugging. + printed_attr << PrintMapAsAttributeValue(GetRef>(base_map_node)); + } + } else if (const auto* global_var_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(global_var_node)); + } else { + printed_attr << "'" << global_var_node->name_hint << "'"; + } + } else { + printed_attr = VisitAttr(value); } - return docs; + return printed_attr; + } else { + return Doc::Text("None"); } } -std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { +Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) { std::vector docs; - if (!attrs.defined()) return docs; - const auto* dict_attrs = attrs.as(); - ICHECK(dict_attrs); - for (const auto& k : dict_attrs->dict) { + AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false); + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + return doc; +} + +Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& map) { + std::vector docs; + for (const auto& k : map) { Doc doc; - doc << k.first << "=" << Print(k.second); + doc << PrintAttributeValue(k.first); + doc << "="; + doc << PrintAttributeValue(k.second); docs.push_back(doc); } - return docs; + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + return doc; } Doc RelayTextPrinter::PrintSpan(const Span& span) { diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index b8533a5d8801..444cb0828c94 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -58,6 +58,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { os << "def @" << kv.first->name_hint; doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); } else if (kv.second.as()) { + doc << "@" << kv.first->name_hint << " = "; doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); } doc << Doc::NewLine(); diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 316d59631782..ebd667ae2ac7 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -77,9 +77,42 @@ class RelayTextPrinter : public ExprFunctor, // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const ObjectRef& node); Doc PrintFinal(const ObjectRef& node); - Doc PrintAttrs(const Attrs& attrs); + + /*! + * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence + * of key=value entries, if any. + */ + void AppendGenericAttrs(std::vector* docs, const Attrs& attrs, bool include_type_key); + + /*! + * \brief Returns \p attrs printed as a sequence of key=value entries, if any. + * This is used for call attributes. + */ std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); - std::vector PrintFuncAttrs(const Attrs& attrs); + + /*! + * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any. + * This is used for function definition attributes. + */ + std::vector PrintDictAttrs(const DictAttrs& dict_attrs); + std::vector PrintDictAttrs(const Map& dict_attrs); + + /*! + * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta + * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag. + */ + Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false); + + /*! + * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces. + */ + Doc PrintAttrsAsAttributeValue(const Attrs& attrs); + + /*! + * \brief Returns \p map printed as a self-contained value, ie wrapped in braces. + */ + Doc PrintMapAsAttributeValue(const Map& map); + Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); @@ -162,7 +195,6 @@ class RelayTextPrinter : public ExprFunctor, //------------------------------------ // Overload of Attr printing functions //------------------------------------ - Doc PrintAttr(const ObjectRef& value, bool meta = false); Doc VisitAttrDefault_(const Object* op) final; Doc VisitAttr_(const ArrayNode* op) final; Doc VisitAttr_(const tir::IntImmNode* op) final; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 302c4491cebe..e479af1b2fe9 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -71,6 +72,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { return PrintString(node.as()); } else if (node->IsInstance()) { return PrintBufferRegion(node.as()); + } else if (node->IsInstance()) { + return Doc::Text(node.as()->ToDebugString()); } else { return this->meta_->GetMetaNode(node); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index a596e09907d5..13b855624461 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include "../op/annotation/annotation.h" #include "../transforms/pass_utils.h" @@ -292,8 +293,11 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(IRModule unified_mod, Device device, Target target) - : unified_mod_(unified_mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {} + Interpreter(IRModule unified_mod, CompilationConfig config, Device device) + : unified_mod_(unified_mod), + config_(std::move(config)), + device_(device), + debug_op_(Op::Get("debug")) {} template T WithFrame(const Frame& fr, const std::function& f) { @@ -386,12 +390,12 @@ class Interpreter : public ExprFunctor, per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module); auto mod_itr = per_target_module_std_map.find(target); ICHECK(mod_itr != per_target_module_std_map.end()) - << "No target module for target '" << target->str() << "'"; + << "No target module for target " << target->ToDebugString(); const IRModule& target_module = (*mod_itr).second; for (const auto& var : all_tir_fn_vars) { ICHECK(target_module->ContainGlobalVar(var->name_hint)) - << "No global var for '" << var->name_hint << "' in module for target '" << target->str() - << "'"; + << "No global var for '" << var->name_hint << "' in module for target " + << target->ToDebugString(); lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint)); } @@ -407,8 +411,9 @@ class Interpreter : public ExprFunctor, // Extract all the packed functions. for (const auto& var : all_tir_fn_vars) { PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); - ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint - << "' in compiled module for target '" << target->str() << "'"; + ICHECK(packed_func != nullptr) + << "No packed function for global var '" << var->name_hint + << "' in compiled module for target " << target->ToDebugString(); compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } @@ -734,9 +739,11 @@ class Interpreter : public ExprFunctor, Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); } - return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, target_, - prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, - num_shape_inputs, num_shape_outputs, cpu_target_, args); + ICHECK(config_->optional_homogeneous_target.defined()); + return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, + config_->optional_homogeneous_target, prim_shape_fn_var, + all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, + num_shape_outputs, config_->host_se_scope->target, args); } } @@ -884,13 +891,11 @@ class Interpreter : public ExprFunctor, // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; + /*! \brief Compilation config describing the available targets. */ + CompilationConfig config_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) Device device_; - // Unique target describing how to compile for primitives (but not shape functions). - Target target_; - // Default 'CPU' target for shape primitives. - Target cpu_target_{"llvm"}; // Call stack. Stack stack_; // The distinguished 'debug' operator, which is handled specially. @@ -898,25 +903,21 @@ class Interpreter : public ExprFunctor, }; /*! - * Lowers all calls to primitives in \p mod appropriate for device and target. Returns the + * Lowers all calls to primitives in \p mod appropriate for \p config. Returns the * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -IRModule Prepare(IRModule mod, Device device, Target target) { - // Things to initialize to pass into tec::LowerTEPass - // We only have one device-specific target. - tec::TargetMap targets = {{device.device_type, target}}; - if (device.device_type != kDLCPU) { - // However some primitives (eg dynamic shape functions) must always execute on the CPU, - // so make sure we have a target for that. - targets.emplace(kDLCPU, Target("llvm")); +IRModule Prepare(IRModule mod, CompilationConfig config) { + tec::TargetMap tec_target_map; + for (const auto& pair : config->legacy_target_map) { + tec_target_map.emplace(static_cast(pair.first->value), pair.second); } - // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), // Figure out which devices should be used to execute. - transform::PlanDevices(device.device_type), + // TODO(mbs): Should ignore all existing annotations when constant folding + transform::PlanDevices(config->default_primitive_se_scope->device_type()), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. transform::FuseOps(/*fuse_opt_level=*/0), @@ -926,7 +927,8 @@ IRModule Prepare(IRModule mod, Device device, Target target) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(targets, /*module_name=*/"intrp", [](Function func) { /* no-op */ })}); + tec::LowerTEPass(tec_target_map, /*module_name=*/"intrp", + [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); @@ -979,7 +981,15 @@ class NeedsPreparationVisitor : public ExprVisitor { TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, Target target) { VLOG_CONTEXT << "EvalFunction"; - VLOG(1) << "evaling module:\n" << PrettyPrint(mod) << "and expression:\n" << PrettyPrint(expr); + VLOG(1) << "evaling module:" << std::endl + << PrettyPrint(mod) << "and expression:" << std::endl + << PrettyPrint(expr); + + ICHECK_EQ(device.device_type, target->kind->device_type); + TargetMap targets; + targets.Set(device.device_type, target); + CompilationConfig config(transform::PassContext::Current(), targets, + /*optional_host_target_arg=*/{}); // // Step 1: Prepare mod. @@ -1024,9 +1034,9 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - IRModule lowered_mod = Prepare(mod_with_expr, device, target); + IRModule lowered_mod = Prepare(mod_with_expr, config); - std::shared_ptr intrp = std::make_shared(lowered_mod, device, target); + std::shared_ptr intrp = std::make_shared(lowered_mod, config, device); // // Step 2: Evaluate target function to a closure. @@ -1065,12 +1075,18 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { + ICHECK_EQ(device.device_type, target->kind->device_type); + TargetMap targets; + targets.Set(device.device_type, target); + CompilationConfig config(transform::PassContext::Current(), targets, + /*optional_host_target_arg=*/{}); + std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - IRModule mod = Prepare(mod_and_global.first, device, target); + IRModule mod = Prepare(mod_and_global.first, config); - Interpreter intrp(mod, device, target); + Interpreter intrp(mod, config, device); Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint); if (expr.as() == nullptr) { // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index b3491d656625..37f6e1e3d15a 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -61,32 +61,48 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex if (host_target.defined()) { CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; host_device_type = static_cast(host_target->kind->device_type); - if (host_device_type != kDLCPU) { - LOG(WARNING) << "Using the given host target '" << host_target << "' of non-CPU device type " - << host_device_type << " for all host operations and data"; - } else { - LOG(INFO) << "Using the given host target '" << host_target << "' of device type " - << host_device_type << " for all host operations and data"; + LOG(INFO) << "Using the given host target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; + for (const auto& primitive_target : primitive_targets) { + if (primitive_target->host.defined() && + !StructuralEqual()(primitive_target->host, host_target)) { + LOG(WARNING) << "The primitive target " << primitive_target->ToDebugString() + << " already has a host which disagrees with the desired host target. It " + "will be ignored."; + } } + } else if (primitive_targets.size() == 1 && primitive_targets.front()->host.defined()) { + host_target = primitive_targets.front()->GetHost().value(); + CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; + host_device_type = static_cast(host_target->kind->device_type); + LOG(INFO) << "Using the host of the unique primitive target, namely " + << host_target->ToDebugString() << " of device type " << host_device_type + << " for the host target"; } else if (primitive_targets.size() == 1 && primitive_targets.front()->kind->device_type == kDLCPU) { // In the homogenous case without an explicit host target just use the given target so long as - // it's a CPU. However make sure we 'forget' any host it may already have. + // it's a CPU. host_device_type = kDLCPU; - host_target = Target(primitive_targets.front()); - LOG(INFO) << "Using the unique target '" << host_target << "' of device type " - << host_device_type << " for all host operations and data"; + host_target = primitive_targets.front(); + LOG(INFO) << "Using the unique primitive target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; } else { // Fallback. host_device_type = kDLCPU; // Even if the list of available targets already includes one for kDLCPU we won't use it - // since its options may not be appropriate for host code (eg shape functions). Instead, - // create a fresh default Target. + // in the hetrogeneous case since its options may not be appropriate for host code + // (eg shape functions). Instead, create a fresh default Target. host_target = MakeDefaultTarget(host_device_type); - LOG(WARNING) << "Using the default host target '" << host_target << "' of device type " - << host_device_type << " for all host operations and data"; + LOG(WARNING) << "Using the default target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; } ICHECK(host_target.defined()); + ICHECK(!host_target->host.defined()); + + if (host_device_type != kDLCPU) { + // I think we're on thin ice here until we've audited the code base for assumed kDLCPU. + LOG(WARNING) << "The host target is not a CPU."; + } // // Establish the host SEScope. @@ -112,24 +128,19 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); if (opt_fallback_dev) { const int64_t v = opt_fallback_dev.value()->value; - if (v <= 0) { - LOG(FATAL) - << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " - << v; - default_primitive_device_type = kDLCPU; - } else { - default_primitive_device_type = static_cast(v); - LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " - << default_primitive_device_type - << " as the default device type for all primitive operations"; - } + CHECK_GT(v, 0) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " << v; + default_primitive_device_type = static_cast(v); + LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; } else if (primitive_targets.size() == 1) { // In the homogeneous case there's no free choice. default_primitive_device_type = static_cast(primitive_targets.front()->kind->device_type); - LOG(INFO) << "Using the unique target '" << primitive_targets.front() << "' of device type " - << default_primitive_device_type - << " as the default device type for all primitive operations"; + LOG(INFO) << "Using the device type " << default_primitive_device_type + << " of the unique primitive target as the default device type for all primitive " + "operations"; } else { // Fallback. Note that we'll require a primitive Target of kDLCPU device_type to be given // and won't manufacture one out of thin air. @@ -154,6 +165,7 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex return Target("llvm"); } else { // LLVM is not available. + // TODO(mbs): Already deprecated? return Target("stackvm"); } } else { @@ -178,10 +190,10 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, auto node = make_object(); for (const auto& pair : legacy_target_map_arg) { - VLOG(0) << "Available primitive target " << pair.first << " = '" << pair.second << "'"; + VLOG(0) << "Available primitive target " << pair.first << " = " << pair.second->ToDebugString(); } if (optional_host_target_arg.defined()) { - VLOG(0) << "Available host target '" << optional_host_target_arg << "'"; + VLOG(0) << "Available host target " << optional_host_target_arg->ToDebugString(); } // Capture the arguments in our representation. @@ -210,8 +222,8 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, node->primitive_targets.size() == 1 ? *node->primitive_targets.begin() : Target(); for (const auto& target : node->primitive_targets) { - LOG(INFO) << "Target '" << target << "' of device type " << target->kind->device_type - << " is available for primitives"; + LOG(INFO) << "Target " << target->ToDebugString() << " of device type " + << target->kind->device_type << " is available for primitives"; } LOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope; LOG(INFO) << "Using host scope " << node->host_se_scope; diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc index 150a883cb565..95d5a7de5775 100644 --- a/src/target/se_scope.cc +++ b/src/target/se_scope.cc @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (need_sep) { p->stream << ", "; } - p->stream << "target='" << node->target << "'"; + p->stream << "target=" << node->target->ToDebugString(); need_sep = true; } if (!node->memory_scope.empty()) { @@ -62,13 +62,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "memory_scope='" << node->memory_scope << "'"; } } +#if TVM_LOG_DEBUG + // We rely on object identity of SEScopes, so include the object address to help debugging. + p->stream << ", id=" << reinterpret_cast(ref.get()); +#endif p->stream << ")"; }); SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, MemoryScope memory_scope) { ICHECK(!target.defined() || device_type == target->kind->device_type) - << "target '" << target << "' has device type " << target->kind->device_type + << "target " << target->ToDebugString() << " has device type " << target->kind->device_type << " but scope has device type " << device_type; auto node = make_object(); node->device_type_int = device_type; @@ -173,7 +177,7 @@ SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Targ cache_.emplace(prototype); return prototype; } else { - VLOG(1) << "reusing '" << *itr << "' for '" << prototype << "'"; + VLOG(1) << "reusing existing scope " << *itr; ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); if (prototype->target.defined()) { ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); diff --git a/src/target/target.cc b/src/target/target.cc index d1c85c583b3b..6f5e8ee67b30 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -457,6 +457,7 @@ const std::string& TargetNode::str() const { if (Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { os << ' ' << attrs_str.value(); } + str_repr_ = os.str(); } return str_repr_; @@ -531,6 +532,48 @@ Optional TargetNode::GetHost() const { return GetRef>(this->host.as()); } +String TargetNode::ToDebugString() const { + std::ostringstream os; + os << "Target("; + os << "kind='" << kind->name << "'"; + if (!tag.empty()) { + os << ", tag='" << tag << "'"; + } + if (!keys.empty()) { + os << ", keys={"; + bool first = true; + for (const auto& key : keys) { + if (!first) { + os << ", "; + } + os << "'" << key << "'"; + first = false; + } + os << "}"; + } + if (!attrs.empty()) { + os << ", attrs={"; + bool first = true; + for (const auto& pair : attrs) { + if (!first) { + os << ", "; + } + os << '"' << pair.first << "': " << pair.second; + first = false; + } + os << "}"; + } + if (host.defined()) { + os << ", host=" << GetHost().value()->ToDebugString(); + } +#if TVM_LOG_DEBUG + // We depend on pointer equality so include that in the debug representation. + os << ", id=" << reinterpret_cast(this); +#endif + os << ")"; + return os.str(); +} + bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { return equal(kind.get(), other->kind.get()) && equal(host, other->host) && equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index 5c2b7990a498..ae5f5d0c3dc4 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -40,13 +40,13 @@ CompilationConfig TestCompilationConfig() { return CompilationConfig(pass_ctx, legacy_target_map, TestDefaultCpuTarget()); } -TEST(CompilationConfig, Constructor_Homogeneous_DefaultHost) { +TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) { transform::PassContext pass_ctx = transform::PassContext::Create(); Target host_target = TestDefaultCpuTarget(); Target cuda_target = TestCudaTarget(); TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); SEScope expected_default_primitive_se_scope(kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); @@ -68,7 +68,31 @@ TEST(CompilationConfig, Constructor_Homogeneous_DefaultHost) { Target::WithHost(cuda_target, host_target))); } -TEST(CompilationConfig, Constructor_Hetrogeneous_DefaultHost) { +TEST(CompilationConfig, Constructor_Homegenoous_InnerHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestCpuTarget(); + Target cuda_target = Target::WithHost(TestCudaTarget(), host_target); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); +} + +TEST(CompilationConfig, Constructor_Homogenous_CPUHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + EXPECT_TRUE(StructuralEqual()(config->host_target, cpu_target)); + ASSERT_TRUE(config->optional_homogeneous_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, + Target::WithHost(cpu_target, cpu_target))); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { transform::PassContext pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); Target host_target = TestDefaultCpuTarget(); @@ -77,7 +101,7 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_DefaultHost) { TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); SEScope expected_default_primitive_se_scope(kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); @@ -123,7 +147,7 @@ TEST(CompilationConfig, Constructor_InvalidAttribute) { TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); EXPECT_ANY_THROW( - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { @@ -132,7 +156,7 @@ TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); EXPECT_ANY_THROW( - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { @@ -141,7 +165,7 @@ TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); legacy_target_map.Set(Integer(static_cast(kDLExtDev)), TestExtDevTarget()); EXPECT_ANY_THROW( - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } TEST(CompilationConfig, CanonicalSEScope) { diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2834bba9248b..21c460fa0371 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -254,7 +254,7 @@ def test_null_attribute(): z = relay.Function([x], y) z = z.with_attr("TestAttribute", None) txt = astext(z) - assert "TestAttribute=(nullptr)" in txt + assert "TestAttribute=None" in txt def test_span():