Skip to content

Commit

Permalink
Slight optimize the default injective schedule (apache#7158)
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored Dec 24, 2020
1 parent 68e7838 commit e27ad08
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/tvm/topi/cuda/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def schedule_injective_from_existing(sch, out):
need_block_split = const_size > max_block * num_thread * vector_width
except ValueError:
need_block_split = False
const_size = 0

if vector_width > 1:
fused, v = sch[out].split(fused, vector_width)
Expand All @@ -72,7 +73,10 @@ def schedule_injective_from_existing(sch, out):
# Use less threads for dynamic shape ops to avoid runtime error.
if is_dynamic_output:
num_thread //= 2
bx, tx = sch[out].split(fused, factor=num_thread)
if const_size != 0 and const_size < num_thread:
bx, tx = sch[out].split(fused, factor=const_size)
else:
bx, tx = sch[out].split(fused, factor=num_thread)
sch[out].bind(tx, te.thread_axis("threadIdx.x"))
sch[out].bind(bx, te.thread_axis("blockIdx.x"))

Expand Down

0 comments on commit e27ad08

Please sign in to comment.