Skip to content

Commit

Permalink
[TIR][USMP] adding the pass to convert to pool offsets
Browse files Browse the repository at this point in the history
Fixing the references after changes in the memory planning
algorithm.

Change-Id: Id7c22356fd5de43d10a2b4fc70e978af2c6d599d
  • Loading branch information
manupak committed Dec 1, 2021
1 parent d66cd2c commit e5f1867
Showing 1 changed file with 53 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde
# fmt: on


def test_linear():
def test_mobilenet_subgraph():
target = Target("c")
fast_memory_pool = usmp_utils.PoolInfo(
pool_name="fast_memory",
Expand Down Expand Up @@ -231,6 +231,7 @@ def test_linear():
)(tir_mod)

tir_mod_with_offsets_ref = LinearStructurePlanned
tir_mod_with_offsets_ref = tvm.script.from_source(tir_mod_with_offsets_ref.script(show_meta=False))
# The TIR produced fails on roundtrip TVMScript testing.
# Therefore, indicates the TVMScript produced here and/or the parser
# is lacking functionality. Thus for these tests, uses a string
Expand Down Expand Up @@ -365,40 +366,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
@tvm.script.ir_module
class ResnetStructurePlanned:
@T.prim_func
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None:
placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 0), dtype="handle")
for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 1478912), dtype="handle")
for ff in T.serial(0, 64):
T.store(Conv2dOutput_let, ff, 0, True)
for rc in T.serial(0, 64):
T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
for ax3_inner_1 in T.serial(0, 64):
T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)

@T.prim_func
def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None:
global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None:
placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle")
sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle")
sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle")
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32"))
for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)

@T.prim_func
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None:
Expand All @@ -407,13 +382,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 5760000), dtype="handle")
PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle")
for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle")
Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle")
for ax3_outer_2 in T.serial(0, 4):
for ff_3 in T.serial(0, 64):
T.store(Conv2dOutput_3_let, ff_3, 0, True)
Expand All @@ -428,13 +403,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 5760000), dtype="handle")
PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle")
for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 6480000), dtype="handle")
Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle")
for ax3_outer_1 in T.serial(0, 4):
for ff_2 in T.serial(0, 64):
T.store(Conv2dOutput_2_let, ff_2, 0, True)
Expand All @@ -444,38 +419,65 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)

@T.prim_func
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None:
placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None:
placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle")
for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle")
for ff in T.serial(0, 64):
T.store(Conv2dOutput_let, ff, 0, True)
for rc in T.serial(0, 64):
T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
for ax3_inner_1 in T.serial(0, 64):
T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)

@T.prim_func
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None:
placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle")
for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 1478912), dtype="handle")
Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle")
for ff_1 in T.serial(0, 64):
T.store(Conv2dOutput_1_let, ff_1, 0, True)
for ry, rx, rc_1 in T.grid(3, 3, 64):
T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
for ax3_inner_2 in T.serial(0, 64):
T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)

@T.prim_func
def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None:
global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle")
sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle")
sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32"))
__tvm_meta__ = None
# fmt: on


def test_fanout():
def test_resnet_subgraph():
target = Target("c")
global_workspace_pool = usmp_utils.PoolInfo(
pool_name="global_workspace",
Expand All @@ -498,6 +500,7 @@ def test_fanout():
)(tir_mod)

tir_mod_with_offsets_ref = ResnetStructurePlanned

# The TIR produced fails on roundtrip TVMScript testing.
# Therefore, indicates the TVMScript produced here and/or the parser
# is lacking functionality. Thus for these tests, uses a string
Expand Down

0 comments on commit e5f1867

Please sign in to comment.