diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index d37fbeabc277..3069342524f5 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -641,6 +641,17 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) { } } +TargetMap GetTecTargetMapFromTargetMap(Map targets) { + tec::TargetMap targets_; + for (const auto& it : targets) { + auto dev_type = it.first.as(); + ICHECK(dev_type); + targets_[static_cast(dev_type->value)] = it.second; + } + + return targets_; +} + Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name, TECompiler compiler, std::function process_fn) { runtime::TypedPackedFunc pass_func = diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index d5135e6301c4..291f0ec8b613 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -149,6 +149,14 @@ void UpdateFunctionMetadata(Function relay_func, */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +/*! + * \brief Obtain the TargetMap from the Map. + * + * \param targets + * \return TargetMap + */ +TargetMap GetTecTargetMapFromTargetMap(Map targets); + /*! * \brief Update the "main" control function's metadata * @@ -196,7 +204,7 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \returns The pass which lowers primative functions to TIR + * \returns The pass which lowers primitive functions to TIR */ transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name, std::function process_fn); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 723a0ea6ee7e..6eb7ee6d3c07 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -921,7 +921,6 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe context_.module = OptimizeModule(mod, targets_, target_host_); // Populate the global map. - // // This maps global variables to a global index // in the VMFunction table. PopulateGlobalMap(); @@ -1083,6 +1082,13 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, pass_seqs.push_back(transform::InferType()); pass_seqs.push_back(transform::LabelOps()); + tec::TargetMap targets_; + tec::DeviceMap device_map; + + targets_ = GetTecTargetMapFromTargetMap(targets); + + pass_seqs.push_back(tec::LowerTEPass(targets_, device_map, "vm_mod", [this](Function func) {})); + transform::Sequential seq(pass_seqs); tvm::With ctx(pass_ctx); if (targets.size() == 1) { @@ -1090,6 +1096,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, With tctx((*it).second); return seq(mod); } + return seq(mod); }