Skip to content

Commit

Permalink
[TOPI] Expose mem_scope from generic conv2d variants to be more reusa…
Browse files Browse the repository at this point in the history
…ble (apache#13680)

Expose mem_scope from generic conv2d variants to be more reusable
  • Loading branch information
cbalint13 authored and fzi-peccia committed Mar 27, 2023
1 parent f121fd7 commit aa0699c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/tvm/topi/generic/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
int8_elems=4,
intrin=None,
inline_fused=True,
mem_scope="global",
):
"""
Defines the schedule for INT8 for Intel and ARM machines
Expand Down Expand Up @@ -186,7 +187,7 @@ def schedule_conv_NCHWc_cpu_common_int8(

# schedule 5-D NCHW[x]c conv
C, O = conv_out, last
CC = s.cache_write(C, "global")
CC = s.cache_write(C, mem_scope)

batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
Expand Down Expand Up @@ -279,6 +280,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
int8_elems=4,
intrin=None,
inline_fused=False,
mem_scope="global",
):
"""
Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
Expand Down Expand Up @@ -323,7 +325,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
s[kernel_vec].parallel(parallel_axis)

C, O = conv_out, last
CC = s.cache_write(C, "global")
CC = s.cache_write(C, mem_scope)

batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
Expand Down

0 comments on commit aa0699c

Please sign in to comment.