diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 1eeef7db5cb5..a9e744cc58d1 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -145,13 +145,15 @@ bool ConcatType(const nnvm::NodeAttrs& attrs, int dtype = -1; // checks uniformity of input - for (int i : *in_type) { + for (size_t i =0; i < in_type->size(); ++i) { if (dtype == -1) { - dtype = i; + dtype = in_type->at(i); } else { - CHECK(i == dtype || - i == -1) << - "Non-uniform data type in Concat"; + CHECK(in_type->at(i) == dtype || in_type->at(i) == -1) + << "Non-uniform data type in " << attrs.op->name + << ", expected data type " << mxnet::op::type_string(dtype) + << ", got data type " << mxnet::op::type_string(in_type->at(i)) + << " for input " << i; } }