Skip to content

Commit

Permalink
simplify gen_conv2d.py
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 1, 2021
1 parent f3b7e13 commit 61040bf
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,33 @@ def create_conv2d_operator(
for iterator_algorithm in iterator_algorithms:
op_entry = {}

op = Conv2dOperation(
ConvKind.Fprop,
iterator_algorithm,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
StrideSupport.Strided,
EpilogueFunctor.LinearCombination,
swizzling_functor_,
)

# TODO(masahi): Add profiler source here
op_entry["opdef"] = kernel_emitter.emit(op)
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
op_entry["runtime"] = 9999999

# fused ops
for epilogue, opdef in zip(
[
EpilogueFunctor.LinearCombination,
EpilogueFunctor.LinearCombinationBias,
EpilogueFunctor.LinearCombinationRelu,
],
["opdef", "opdef_bias", "opdef_bias_relu"],
["opdef_bias", "opdef_bias_relu"],
):
op = Conv2dOperation(
ConvKind.Fprop,
Expand All @@ -82,11 +102,6 @@ def create_conv2d_operator(

op_entry[opdef] = kernel_emitter.emit(op)

if epilogue == EpilogueFunctor.LinearCombination:
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
op_entry["runtime"] = 9999999

ret.append(op_entry)

return ret
Expand Down

0 comments on commit 61040bf

Please sign in to comment.