diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index 7f0790aebf4d..cce56b796cea 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -57,6 +57,7 @@ def schedule_injective_from_existing(sch, out): need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False + const_size = 0 if vector_width > 1: fused, v = sch[out].split(fused, vector_width) @@ -72,7 +73,10 @@ def schedule_injective_from_existing(sch, out): # Use less threads for dynamic shape ops to avoid runtime error. if is_dynamic_output: num_thread //= 2 - bx, tx = sch[out].split(fused, factor=num_thread) + if const_size != 0 and const_size < num_thread: + bx, tx = sch[out].split(fused, factor=const_size) + else: + bx, tx = sch[out].split(fused, factor=num_thread) sch[out].bind(tx, te.thread_axis("threadIdx.x")) sch[out].bind(bx, te.thread_axis("blockIdx.x"))