Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] F2qi avgpool bug fix #15599

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,37 @@ def avgpool2d(expr, type_map):
t = type_map[arg]
out_t = type_map[expr]

# dq > nn.avg_pool2d > q
# Use the same input quantization parameters for output if the pattern is not the above.
# Type_map is a map of graphs and their Tensoraffinetypes
# Find the current "nn.avg_pool2d" op after checking for the "qnn.quantize" op in the graph.
# Structure for .. dq > op > q will be q [op [dq ..
def check(y, expr):
if isinstance(y, type(expr)):
if y.op.name != "nn.avg_pool2d":
return True
# check if this is the expr avg_pool
if y.attrs != expr.attrs:
return True
return False

for x in type_map.items():
if isinstance(x[0], type(expr)):
if x[0].op.name == "qnn.quantize":
prev = x[0]
y = prev.args[0]
while check(y, expr):
prev = y
y = prev.args[0]
if (
isinstance(y, type(expr))
and y.op.name == "nn.avg_pool2d"
and y.attrs == expr.attrs
):
if prev.op.name != "qnn.quantize":
out_t = t
break

out = relay.qnn.op.avg_pool2d(
arg,
t.scale,
Expand Down
22 changes: 10 additions & 12 deletions python/tvm/topi/hexagon/qnn/avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ def saturate(x: te.Tensor, dtype: str):
return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype)))


def get_temp_dtype(h, w, dtype):
temp_dtype = "int16" if h * w < 256 else "int32"
if dtype in ("uint8", "int8"):
return temp_dtype
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")


def qnn_avg_pool2d_NCHW(
data: te.Tensor,
kernel: list,
Expand All @@ -59,12 +67,7 @@ def qnn_avg_pool2d_NCHW(
rh = te.reduce_axis((0, kh), name="rh")
rw = te.reduce_axis((0, kw), name="rw")

if odtype == "uint8":
temp_dtype = "uint16"
elif odtype == "int8":
temp_dtype = "int16"
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")
temp_dtype = get_temp_dtype(kh, kw, odtype)

sh, sw = stride
dh, dw = dilation
Expand Down Expand Up @@ -155,12 +158,7 @@ def qnn_avg_pool2d_NHWC(
rh = te.reduce_axis((0, kh), name="rh")
rw = te.reduce_axis((0, kw), name="rw")

if odtype == "uint8":
temp_dtype = "uint16"
elif odtype == "int8":
temp_dtype = "int16"
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")
temp_dtype = get_temp_dtype(kh, kw, odtype)

sh, sw = stride
dh, dw = dilation
Expand Down
22 changes: 22 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy
import tvm
from tvm import te
from tvm.relay.backend import Executor


def ceildiv(o, d):
Expand Down Expand Up @@ -112,6 +113,27 @@ def build_and_run(inputs, func, target: str, target_host: str, *args, **kwargs):
return tensors[-1].asnumpy()


def build_module(relay_mod, target):
"""builds a relay module for a specified target"""
params = {}
executor = Executor("aot", {"link-params": True})
lowered = tvm.relay.build(
relay_mod,
tvm.target.Target(target, host=target),
executor=executor,
params=params,
)
return lowered


def run_module(mod, inputs):
"""invokes run function of specified module with inputs provided"""
mod.set_input(**inputs)
mod.run()
output = mod.get_output(0).numpy()
return output


def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, out_channels):
assert len(shape_nhwc) == 4
kernel = []
Expand Down
Loading