diff --git a/apps/hannk/interpreter/ops.cpp b/apps/hannk/interpreter/ops.cpp
index ebda094a6410..45806f0f6429 100644
--- a/apps/hannk/interpreter/ops.cpp
+++ b/apps/hannk/interpreter/ops.cpp
@@ -484,24 +484,62 @@ void mul_uint8(const HalideBuffer &in1, const QuantizationInfo &in1q
     elementwise_loop_nest<2>(mul_rank2, in1, in2, out);
 }
 
-void requantize(const HalideBuffer &in, const QuantizationInfo &inq,
-                HalideBuffer out, const QuantizationInfo &outq,
-                ActivationFunction activation = ActivationFunction::None) {
-    if (inq == outq) {
-        // Some of these are just copies, or no-ops.
-        if (is_alias(in.raw_buffer(), out.raw_buffer())) {
-            return;
-        } else {
-            out.copy_from(in);
-        }
-    } else if (in.type() == halide_type_of() &&
-               out.type() == halide_type_of()) {
+// Requantizes `in` (with quantization `inq`) into `out` (with quantization
+// `outq`), forwarding `activation` to add_uint8. Only uint8 buffers are
+// handled.
+//
+// Returns true if `out` was written. Returns false without touching `out` if
+// the buffer types differ (an error is logged) or if they are not uint8
+// (nothing is logged, so callers can fall back to another strategy).
+bool try_requantize(const HalideBuffer &in, const QuantizationInfo &inq,
+                    HalideBuffer out, const QuantizationInfo &outq,
+                    ActivationFunction activation = ActivationFunction::None) {
+    if (in.type() != out.type()) {
+        HLOG(ERROR) << "requantize: input and output types must match";
+        return false;
+    }
+
+    // The types are known to match here, so checking `in` alone suffices.
+    if (in.type() == halide_type_of()) {
         // TODO: Maybe a dedicated pipeline for this would be better. It
         // could be a little faster, and avoid some quantization error.
         add_uint8(in, inq, 1, in, inq, 0, out, outq, activation);
-    } else {
-        HLOG(FATAL) << "Unable to requantize " << in.type() << " -> " << out.type() << "\n";
+        return true;
+    }
+
+    return false;
+}
+
+// Writes `in` into `out`. Input and output buffer types must match.
+// uint8 buffers are always requantized from `inq` to `outq` via
+// try_requantize, even when the two quantizations are equal. Any other type
+// is simply copied (or left alone when `out` aliases `in`), which is only
+// valid when `inq` and `outq` are equal; `activation` is not applied on this
+// path.
+//
+// Returns true on success. Returns false, after logging an error, if the
+// types differ or if a non-uint8 buffer would need requantizing.
+bool requantize_or_copy(const HalideBuffer &in, const QuantizationInfo &inq,
+                        HalideBuffer out, const QuantizationInfo &outq,
+                        ActivationFunction activation = ActivationFunction::None) {
+    if (in.type() != out.type()) {
+        HLOG(ERROR) << "requantize_or_copy: input and output types must match";
+        return false;
+    }
+    if (try_requantize(in, inq, out, outq, activation)) {
+        return true;
+    }
+
+    // A plain copy cannot convert between quantizations; fail instead of silently producing wrong data.
+    if (!(inq == outq)) {
+        HLOG(ERROR) << "requantize_or_copy: unable to requantize buffers of type " << in.type();
+        return false;
+    }
+
+    if (!is_alias(in.raw_buffer(), out.raw_buffer())) {
+        out.copy_from(in);
     }
+    return true;
 }
 
 ActivationFunction to_activation(UnaryOp::Operator op) {
@@ -728,7 +766,9 @@ void ConcatenationOp::execute() {
 
         auto output_crop = output_buf;
         crop_to_union(output_crop, input_buf);
-        requantize(input_buf, input(i)->quantization(), output_crop, output()->quantization());
+
+        bool copied = requantize_or_copy(input_buf, input(i)->quantization(), output_crop, output()->quantization());
+        HCHECK(copied);
     }
 }
 
@@ -1635,7 +1675,8 @@ void SplitOp::execute() {
         assert(output_buf.dim(axis_).min() == 0);
         output_buf.translate(axis_, concatenated_i);
 
-        requantize(input_buf, input()->quantization(), output_buf, output(i)->quantization());
+        bool copied = requantize_or_copy(input_buf, input()->quantization(), output_buf, output(i)->quantization());
+        HCHECK(copied);
 
         concatenated_i += output_buf.dim(axis_).extent();
     }
@@ -1785,7 +1826,8 @@ void UnaryOp::execute() {
             mul_uint8(in_buf, in->quantization(), in_buf, in->quantization(), out_buf, out->quantization());
             return;
         } else if (op_ == Relu || op_ == Relu6 || op_ == ReluN1To1) {
-            requantize(in_buf, in->quantization(), out_buf, out->quantization(), to_activation(op_));
+            bool copied = try_requantize(in_buf, in->quantization(), out_buf, out->quantization(), to_activation(op_));
+            HCHECK(copied);
             return;
         }
     }