Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix MXPredReshape in the c_predict_api #11493

Merged
merged 15 commits into from
Aug 14, 2018
5 changes: 3 additions & 2 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
}
sym = nnvm::Symbol::CreateGroup(out_syms);
}
ret->sym = sym;

// load the parameters
std::unordered_map<std::string, NDArray> arg_params, aux_params;
Expand Down Expand Up @@ -214,6 +215,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
}

Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
ret->ctx = ctx;

std::vector<NDArray> arg_arrays, aux_arrays;
for (size_t i = 0; i < arg_shapes.size(); ++i) {
Expand All @@ -231,6 +233,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
aux_arrays.push_back(nd);
}
ret->arg_arrays = arg_arrays;
ret->aux_arrays = aux_arrays;
// bind
{
std::map<std::string, Context> ctx_map;
Expand Down Expand Up @@ -309,7 +312,6 @@ int MXPredReshape(mx_uint num_input_nodes,
<< " shape has been changed, only allow to change the shape of input data.";
}
}
p->arg_arrays.clear();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent change. This follows a good design paradigm.


for (size_t i=0; i < aux_names.size(); ++i) {
TShape newShape = aux_shapes[i];
Expand All @@ -319,7 +321,6 @@ int MXPredReshape(mx_uint num_input_nodes,
<< " shape has been changed, only allow to change the shape of input data.";
}
ret->aux_arrays = p->aux_arrays;
p->aux_arrays.clear();

// bind
{
Expand Down