From 90e36c72e8cbe47e336bb6c173eaee8fe2cda141 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Mon, 2 Jan 2023 12:57:42 +0200 Subject: [PATCH] Expose mem_scope from generic conv2d variants to be more reusable --- python/tvm/topi/generic/conv2d.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py index 76cd9a7d69d1..a4a37247c82e 100644 --- a/python/tvm/topi/generic/conv2d.py +++ b/python/tvm/topi/generic/conv2d.py @@ -132,6 +132,7 @@ def schedule_conv_NCHWc_cpu_common_int8( int8_elems=4, intrin=None, inline_fused=True, + mem_scope="global", ): """ Defines the schedule for INT8 for Intel and ARM machines @@ -186,7 +187,7 @@ def schedule_conv_NCHWc_cpu_common_int8( # schedule 5-D NCHW[x]c conv C, O = conv_out, last - CC = s.cache_write(C, "global") + CC = s.cache_write(C, mem_scope) batch, oc_chunk, oh, ow, oc_block = s[C].op.axis ow_chunk, ow_block = s[C].split(ow, factor=reg_n) @@ -279,6 +280,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8( int8_elems=4, intrin=None, inline_fused=False, + mem_scope="global", ): """ Defines the 1x1 conv schedule for INT8 for Intel and ARM machines @@ -323,7 +325,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8( s[kernel_vec].parallel(parallel_axis) C, O = conv_out, last - CC = s.cache_write(C, "global") + CC = s.cache_write(C, mem_scope) batch, oc_chunk, oh, ow, oc_block = s[C].op.axis oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)