Skip to content

Commit

Permalink
Revert
Browse files Browse the repository at this point in the history
  • Loading branch information
trahman-quic1 committed Sep 7, 2022
1 parent 94e808c commit 9527f92
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 101 deletions.
43 changes: 4 additions & 39 deletions python/tvm/topi/hexagon/resize2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,59 +58,24 @@ def resize2d_compute(
)


def tir_resize2d_schedule(
def tir_broadcast_schedule(
out_m,
input_a,
input_layout: str,
output_layout: str,
):
"""Schedule for input and output layout nhwc-8h2w32c2w-2d and nhwc-8h8w32c-2d"""
"""Schedule for input and output layout nhwc-8h2w32c2w-2d"""
func = te.create_prim_func([input_a, out_m])

s = tir.Schedule(func)

block = s.get_block("resize")

if input_layout in (
"nhwc-8h2w32c2w-2d",
"nhwc-8h8w32c-2d",
):
if input_layout == "nhwc-8h2w32c2w-2d":
input_transformed_layout = get_layout_transform_fn(input_layout)
s.transform_layout(block, buffer=("read", 0), index_map=input_transformed_layout)

output_transformed_layout = get_layout_transform_fn(output_layout)
s.transform_layout(block, buffer=("write", 0), index_map=output_transformed_layout)

if output_layout == "nhwc-8h2w32c2w-2d":
# Fixed chunk size is 2048 byte
# For fp16 the layout for fixed chunk is 8x4x32
# where each element is 2 bytes
# Split and reorder is done to iterate over the fixed chunk
# Channel is split by a factor of 32
# Width is split by a factor of 4
# Height is split by a factor of 8
n, h, w, c = s.get_loops(block)

ho, hi = s.split(h, [None, 8])
wo, wi = s.split(w, [None, 4])
co, ci = s.split(c, [None, 32])

s.reorder(n, ho, wo, co, hi, wi, ci)

elif output_layout == "nhwc-8h8w32c-2d":
# Fixed chunk size is 2048 byte
# For uint8 the layout for fixed chunk is 8x8x32
# where each element is 1 bytes
# Split and reorder is done to iterate over the fixed chunk
# Channel is split by a factor of 32
# Width is split by a factor of 8
# Height is split by a factor of 8
n, h, w, c = s.get_loops(block)

ho, hi = s.split(h, [None, 8])
wo, wi = s.split(w, [None, 8])
co, ci = s.split(c, [None, 32])

s.reorder(n, ho, wo, co, hi, wi, ci)

return s
return s
11 changes: 2 additions & 9 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,13 +619,6 @@ def _resize_2d(
The computed result with type out_dtype
"""

def saturate(x, dtype):
if dtype == "uint8":
return te.max(0, te.min(x, 255))
elif dtype == "int8":
return te.max(-127, te.min(x, 128))
return x

def _cast_output(value, data_dtype="float32", out_dtype=None):
if out_dtype:
dtype = out_dtype
Expand Down Expand Up @@ -783,7 +776,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
extrapolation_value,
tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out),
)
return _cast_output(saturate(value, out_dtype), data.dtype, out_dtype=out_dtype)
return _cast_output(value, data.dtype, out_dtype=out_dtype)


def resize2d(
Expand Down Expand Up @@ -1373,4 +1366,4 @@ def compute_func(*indices):
out_dtype=out_dtype,
)

return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
13 changes: 2 additions & 11 deletions python/tvm/topi/testing/resize_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ def get_index(x, image_width, target_width, coordinate_transformation_mode):
return out


def saturate(x, dtype):
"""Saturate value for the specified data type"""
if dtype == "uint8":
return np.maximum(0, np.minimum(x, 255))
elif dtype == "int8":
return np.maximum(-127, np.minimum(x, 128))
return x


def resize3d_nearest(arr, scale, coordinate_transformation_mode):
"""Populate the array by scale factor"""
d, h, w = arr.shape
Expand Down Expand Up @@ -169,7 +160,7 @@ def _get_patch(zint, yint, xint):

l = np.sum(p * wx, axis=-1)
col = np.sum(l * wy, axis=-1)
data_out[m, j, k] = saturate(np.sum(col * wz), data_in.dtype).astype(data_in.dtype)
data_out[m, j, k] = np.sum(col * wz)

return data_out

Expand Down Expand Up @@ -282,4 +273,4 @@ def resize3d_python(
if layout == "NDHWC":
output_np = output_np.transpose([0, 2, 3, 4, 1])

return output_np
return output_np
52 changes: 10 additions & 42 deletions tests/python/contrib/test_hexagon/topi/test_resize2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,46 +26,26 @@

@tvm.testing.fixture
def expected_output_np(
input_np,
in_height,
in_width,
out_height,
out_width,
layout,
method,
coord_trans,
dtype,
input_np, in_height, in_width, out_height, out_width, layout, method, coord_trans
):
scale_h = out_height / in_height
scale_w = out_width / in_width

return resize2d_python(input_np, (scale_h, scale_w), layout, method, coord_trans)


@tvm.testing.fixture
def input_np(input_shape, dtype):
if dtype == "float16":
return np.random.random(input_shape).astype(dtype)
elif dtype == "uint8":
return np.random.randint(0, 255, input_shape).astype(dtype)
elif dtype == "int8":
return np.random.randint(-128, 127, input_shape).astype(dtype)
return np.random.random(input_shape).astype(dtype)


@tvm.testing.fixture
def transformed_input_np(input_np, layout, input_crouton_layout, dtype):
if dtype == "float16" or dtype == "uint8" or dtype == "int8":
return transform_numpy(input_np, layout.lower(), input_crouton_layout)
else:
raise RuntimeError(f"Unsupported data type '{dtype}'")
def transformed_input_np(input_np, layout, input_crouton_layout):
return transform_numpy(input_np, layout.lower(), input_crouton_layout)


@tvm.testing.fixture
def transformed_expected_output_np(expected_output_np, layout, output_layout, dtype):
if dtype == "float16" or dtype == "uint8" or dtype == "int8":
return transform_numpy(expected_output_np, layout.lower(), output_layout)
else:
raise RuntimeError(f"Unsupported data type '{dtype}'")
def transformed_expected_output_np(expected_output_np, layout, output_layout):
return transform_numpy(expected_output_np, layout.lower(), output_layout)


@tvm.testing.fixture
Expand Down Expand Up @@ -100,7 +80,6 @@ class TestResize2d:

(layout, input_crouton_layout, output_layout, dtype,) = tvm.testing.parameters(
("NHWC", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
("NHWC", "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
)

coord_trans = tvm.testing.parameter("asymmetric", "align_corners", "half_pixel")
Expand Down Expand Up @@ -133,18 +112,14 @@ def test_resize2d(
layout=layout,
coordinate_transformation_mode=coord_trans,
method=method,
out_dtype=dtype,
)

tir_schedule = s1.tir_resize2d_schedule(M, A, input_crouton_layout, output_layout)
tir_schedule = s1.tir_broadcast_schedule(M, A, input_crouton_layout, output_layout)

sch = tir_schedule.mod

input_axis_separator = [4]
if output_layout in (
"nhwc-8h2w32c2w-2d",
"nhwc-8h8w32c-2d",
):
if output_layout == "nhwc-8h2w32c2w-2d":
output_axis_separator = [4]
else:
raise RuntimeError(f"Unexpected layout '{output_layout}'")
Expand Down Expand Up @@ -180,16 +155,9 @@ def test_resize2d(
# convert nd to np and reshape to fixed chunk size layout
if output_layout == "nhwc-8h2w32c2w-2d":
M_data_np = M_data_nd.numpy().reshape([b, h // 8, w // 4, c // 32, 8, 2, 32, 2])
elif output_layout == "nhwc-8h8w32c-2d":
M_data_np = M_data_nd.numpy().reshape([b, h // 8, w // 8, c // 32, 8, 8, 32])

if dtype == "float16":
np.testing.assert_allclose(
transformed_expected_output_np, M_data_np, rtol=1e-3, atol=1e-3
)
elif dtype == "int8" or dtype == "uint8":
np.testing.assert_allclose(transformed_expected_output_np, M_data_np, rtol=1, atol=1)
np.testing.assert_allclose(transformed_expected_output_np, M_data_np, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
tvm.testing.main()
tvm.testing.main()

0 comments on commit 9527f92

Please sign in to comment.