Skip to content

Commit

Permalink
fixup! [USMP] Adding support for U4 usecase
Browse files Browse the repository at this point in the history
Change-Id: I78f03d36b12b4a5e8eae8d11701f51019489defc
  • Loading branch information
manupak committed Apr 20, 2022
1 parent 8fa92ac commit 10b6504
Showing 1 changed file with 9 additions and 21 deletions.
30 changes: 9 additions & 21 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1112,32 +1112,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
// Extract additional information around main TIR PrimFunc arguments
Array<String> devices = ListDevices();
Array<tir::Var> inputs =
Array<tir::Var>(tir_main_func->params.begin(), tir_main_func->params.begin() +
tir_main_func->params.size() -
return_sid_.size() - devices.size());
const auto main_func_params_end_iterator =
tir_main_func->params.begin() + tir_main_func->params.size();
const auto outputs_begin_iterator =
main_func_params_end_iterator - return_sid_.size() - devices.size();
Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params.begin(), outputs_begin_iterator);
Array<TensorType> input_tensor_types;
for (auto i : inputs) {
input_tensor_types.push_back(io_tensor_types_[i]);
}

Array<tir::Var> outputs =
Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size());
std::vector<String> output_var_names;
if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
Array<String> output_tensor_names = opt.value();
for (size_t i = 0; i < output_tensor_names.size(); ++i) {
output_var_names.push_back(output_tensor_names[i]);
}
}

// If output names have not been specified then generate default output names
if (output_var_names.size() == 0) {
if (return_sid_.size() == 1) {
output_var_names.push_back(String("output"));
} else {
for (size_t i = 0; i < return_sid_.size(); ++i) {
output_var_names.push_back(String("output" + std::to_string(i)));
}
}
for (const tir::Var& output : outputs) {
output_var_names.push_back(output->name_hint);
}

Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};
Expand Down

0 comments on commit 10b6504

Please sign in to comment.