Skip to content

Commit

Permalink
merge relay buildmodule to codegen build
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed May 8, 2019
1 parent eb7caa0 commit ce5b485
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 59 deletions.
17 changes: 16 additions & 1 deletion include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,29 @@ TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
* for heterogeneous build.
* \param input The map contains target to a list of lowered functions pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build device and host code from a map that pairs each target string
* with the list of lowered functions to compile for that target. This overload
* is used for heterogeneous builds where targets are given as strings.
* \param input The map from a target string to the list of lowered functions
* for that target.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, Array<LoweredFunc>>& input,
const Target& target_host,
const BuildConfig& config);

class GenericFuncNode;

/*!
Expand Down
31 changes: 22 additions & 9 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
all_names.insert(x->name);
}

Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;

for (const auto& x : funcs) {
Expand All @@ -451,12 +452,12 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
func = ir::ThreadSync(func, "warp");
func = ir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = ir::SplitHostDevice(func);
fhost->push_back(fsplits[0]);
fhost.push_back(fsplits[0]);
for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
fdevice.push_back(*f);
}
} else if (x->func_type == kHostFunc) {
fhost->push_back(x);
fhost.push_back(x);
} else if (x->func_type == kDeviceFunc) {
fdevice.push_back(x);
} else {
Expand Down Expand Up @@ -492,18 +493,18 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< "\n";
}

for (size_t i = 0; i < fhost->size(); ++i) {
auto func = (*fhost)[i];
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::BindDeviceType(func, target->device_type);
func = ir::LowerTVMBuiltin(func);
fhost->Set(i, func);
fhost.Set(i, func);
}

for (size_t i = 0; i < fhost->size(); ++i) {
auto func = (*fhost)[i];
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::LowerIntrin(func, target_host->target_name);
func = ir::CombineContextCall(func);
fhost->Set(i, func);
fhost.Set(i, func);
}
return {fhost, fdevice};
}
Expand All @@ -517,7 +518,7 @@ runtime::Module DeviceBuild(const Array<LoweredFunc>& funcs,
Array<LoweredFunc>* fhost) {
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
auto& fhost = host_dev_funcs[0];
*fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];

if (!fdevice.empty()) {
Expand Down Expand Up @@ -568,6 +569,18 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
return mhost;
}

/*!
 * \brief Heterogeneous build entry point keyed by target strings.
 *
 * Each key of \p inputs is turned into a Target with Target::create and the
 * resulting Target -> functions map is forwarded to the Target-keyed overload.
 *
 * \param inputs Map from a target string to the lowered functions for it.
 * \param target_host The target used for host code.
 * \param config The build configuration.
 * \return The module produced by the Target-keyed build overload.
 */
runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
                      const Target& target_host,
                      const BuildConfig& config) {
  Map<Target, Array<LoweredFunc>> funcs_by_target;
  for (const auto& kv : inputs) {
    // Re-key the entry: target string -> Target object.
    funcs_by_target.Set(Target::create(kv.first), kv.second);
  }
  return build(funcs_by_target, target_host, config);
}

// Build for homogeneous execution.
runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
Expand Down
51 changes: 2 additions & 49 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -594,52 +594,6 @@ class RelayBuildModule : public runtime::ModuleNode {
}
return func;
}
/*!
* \brief Build module given lowered functions for each target
*
* For every target, the lowered functions are split into host and device
* parts. Device functions are compiled into one module per target, all host
* functions are compiled into a single host module, the device modules are
* imported into the host module, and the result is stored in ret_.mod.
*
* \param lowered_funcs target_str -> Array<LoweredFunc> map
* \param targets Targets map (not referenced in this function body)
* \param cfg Building configuration
*/
void BuildModule(const Map<std::string, Array<LoweredFunc> >& lowered_funcs,
const Map<HalideIR::Expr, HalideIR::Expr>& targets,
const BuildConfig& cfg) {
// The host target is created from the configured fallback device string.
auto target_host = Target::create(cfg_.fallback_device);
// Abort if any single target's function list contains a repeated name.
for (const auto& kv : lowered_funcs) {
std::unordered_set<std::string> fname_set;
for (auto f : kv.second) {
if (fname_set.count(f->name)) {
LOG(FATAL) << "Duplicate function name "
<< f->name;
}
fname_set.insert(f->name);
}
}
// Create a Target object for each target string key.
std::unordered_map<std::string, Target> target_map;
for (const auto& kv : lowered_funcs) {
target_map[kv.first] = Target::create(kv.first);
}
// Host functions from all targets are gathered into one list; device
// modules are collected separately, one per target that has device code.
Array<LoweredFunc> fhost_all;
std::vector<runtime::Module> device_module;
for (const auto& kv : lowered_funcs) {
auto target = target_map[kv.first];
// Element 0 is used as the host functions, element 1 as the device functions.
auto host_dev_funcs = split_dev_host_funcs(kv.second, target, target_host, cfg);
for (auto f : host_dev_funcs[0]) {
fhost_all.push_back(f);
}
// Only build a device module when this target has device functions.
if (host_dev_funcs[1].size()) {
auto mdev = codegen::Build(host_dev_funcs[1], target->str());
device_module.push_back(mdev);
}
}

// Build the single host module from every collected host function.
auto mhost = codegen::Build(fhost_all, target_host->str());

// Attach each device module to the host module.
for (auto mdev : device_module) {
mhost.Import(mdev);
}
ret_.mod = mhost;
}

/*!
* \brief Build relay function to runtime module
Expand Down Expand Up @@ -678,9 +632,8 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();

BuildModule(graph_codegen_->GetLoweredFunc(),
device_target,
tvm_cfg_);
auto target_host = Target::create(cfg_.fallback_device);
ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_);
}

protected:
Expand Down

0 comments on commit ce5b485

Please sign in to comment.