Skip to content

Commit

Permalink
pass a function for autotuner via 'do_bench' param
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev committed Dec 18, 2024
1 parent 5878f0d commit 1a1c98e
Show file tree
Hide file tree
Showing 11 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401
from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401

if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
from triton.runtime import driver
Expand Down
9 changes: 7 additions & 2 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import itertools
import os
from typing import Any, Dict, List
import triton

USE_IPEX_OPTION = os.getenv("USE_IPEX", "1") == "1"
if USE_IPEX_OPTION:
Expand Down Expand Up @@ -237,7 +236,13 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
else:
raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")

triton.testing.do_bench = do_bench

def make_do_bench_for_autotune(kernel_name: str):

def autotuner_do_bench(*args, **kwargs):
return do_bench(*args, n_warmup=10, n_repeat=10, kernel_name=kernel_name, **kwargs)

return autotuner_do_bench


def assert_close(x, y, atol=None, rtol=None, err_msg=""):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
for w in [8, 16, 32] \
]

tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'],
do_bench=benchmark_suit.make_do_bench_for_autotune('_attn_fwd'))
tune_attn_fwd = tuner(_attn_fwd)


Expand Down
1 change: 1 addition & 0 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def naive_softmax(x):
triton.Config({"threads_per_warp": 16}, num_warps=4),
],
key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name="softmax_kernel"),
)
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
num_stages=s, num_warps=32) for s in [2, 3]
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -116,6 +117,7 @@ def matmul_kernel_with_block_pointers(
num_stages=s, num_warps=4) for s in [2]
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -109,6 +110,7 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def gelu(x):
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -122,6 +123,7 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -107,6 +108,7 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
num_stages=4, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='_kernel'),
)
@triton.jit
def _kernel(A, B, C, #
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def mac_loop(
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='first_wave'),
)
@triton.jit
def first_wave(
Expand Down Expand Up @@ -143,6 +144,7 @@ def first_wave(
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='full_tiles'),
)
@triton.jit
def full_tiles(
Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def kernel(x_ptr, x_size, **META):
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
use_cuda_graph=use_cuda_graph)
use_cuda_graph=use_cuda_graph, do_bench=do_bench)

return decorator

Expand Down

0 comments on commit 1a1c98e

Please sign in to comment.