Commit 5710fd1

remove 'kernel_name'
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev committed Dec 21, 2024
1 parent e82f873 commit 5710fd1
Showing 9 changed files with 15 additions and 16 deletions.
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -160,10 +160,10 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
-def make_do_bench_for_autotune(kernel_name: str):
+def make_do_bench_for_autotune():
 
     def autotuner_do_bench(*args, **kwargs):
-        return do_bench(*args, n_warmup=10, n_repeat=10, kernel_name=kernel_name, **kwargs)
+        return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)
 
     return autotuner_do_bench
 
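For orientation, a runnable sketch of the factory pattern after this change. The `do_bench` below is a stand-in for the real timing helper in benchmark_testing.py (assumption: the real one measures device execution time; this one only runs the function), while `make_do_bench_for_autotune` matches the new signature shown above:

    # Stand-in for benchmark_testing.do_bench (assumed behavior: the real
    # helper times `fn` on the device; this one just runs it and returns 0.0).
    def do_bench(fn, n_warmup=25, n_repeat=100, **kwargs):
        for _ in range(n_warmup + n_repeat):
            fn()
        return 0.0

    # After this commit the factory takes no arguments: it only binds the
    # autotuner's warmup/repeat counts into a do_bench-style closure, so
    # call sites no longer pass a kernel name.
    def make_do_bench_for_autotune():

        def autotuner_do_bench(*args, **kwargs):
            return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)

        return autotuner_do_bench

    bench = make_do_bench_for_autotune()
    print(bench(lambda: None))  # -> 0.0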
@@ -161,8 +161,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     for w in [8, 16, 32] \
 ]
 
-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'],
-                        do_bench=benchmark_suit.make_do_bench_for_autotune('_attn_fwd'))
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
 tune_attn_fwd = tuner(_attn_fwd)
 
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -50,7 +50,7 @@ def naive_softmax(x):
         triton.Config({"threads_per_warp": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name="softmax_kernel"),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -39,7 +39,7 @@
             num_stages=s, num_warps=32) for s in [2, 3]
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -113,7 +113,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
@@ -32,7 +32,7 @@
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -107,7 +107,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
@@ -51,7 +51,7 @@ def gelu(x):
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -120,7 +120,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
@@ -33,7 +33,7 @@
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -105,7 +105,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
@@ -12,7 +12,7 @@
             num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='_kernel'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def _kernel(A, B, C, #
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -104,7 +104,7 @@ def mac_loop(
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='first_wave'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def first_wave(
@@ -141,7 +141,7 @@ def first_wave(
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='full_tiles'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def full_tiles(
