forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rmsnorm kernel #633
Merged
Merged
Add rmsnorm kernel #633
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
import argparse | ||
import torch | ||
import sys | ||
import pytest | ||
|
||
import triton | ||
import triton.language as tl | ||
|
||
|
||
def is_cuda(): | ||
return triton.runtime.driver.active.get_current_target().backend == "cuda" | ||
|
||
|
||
def is_hip(): | ||
return triton.runtime.driver.active.get_current_target().backend == "hip" | ||
|
||
|
||
def get_cuda_autotune_config(): | ||
return [ | ||
triton.Config({}, num_warps=4, num_stages=1), | ||
triton.Config({}, num_warps=8, num_stages=1), | ||
triton.Config({}, num_warps=16, num_stages=1), | ||
] | ||
|
||
|
||
def get_hip_autotune_config(): | ||
return [ | ||
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), | ||
] | ||
|
||
|
||
def get_autotune_config(): | ||
if is_cuda(): | ||
return get_cuda_autotune_config() | ||
else: | ||
return get_hip_autotune_config() | ||
|
||
|
||
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) | ||
@triton.jit | ||
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, | ||
BLOCK_SIZE: tl.constexpr): | ||
row_start = tl.program_id(0) | ||
row_step = tl.num_programs(0) | ||
col_offsets = tl.arange(0, BLOCK_SIZE) | ||
mask = col_offsets < n_cols | ||
for row_idx in tl.range(row_start, n_rows, row_step): | ||
row_start_ptr = input_ptr + row_idx * input_row_stride | ||
input_ptrs = row_start_ptr + col_offsets | ||
input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") | ||
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg") | ||
row_norm = row * row #square each value | ||
row_norm = tl.sum(row_norm, axis=-1) #sum across columns(axis=-1) | ||
row_norm = row_norm / n_cols #divide by n_cols | ||
row_norm = row_norm + epsilon #add epsilon | ||
row_norm = tl.rsqrt(row_norm) #take rsqrt, this is normalization value | ||
rms_norm = row * row_norm #multiply each x by normalization value | ||
rms_norm = rms_norm * g #element wise multiplication with g | ||
|
||
output_row_start_ptr = output_ptr + row_idx * output_row_stride | ||
output_ptrs = output_row_start_ptr + col_offsets | ||
output_ptrs = tl.multiple_of(output_ptrs, (16, )) | ||
tl.store(output_ptrs, rms_norm, mask=mask) | ||
|
||
|
||
def rmsnorm(x, epsilon=1e-6): | ||
n_rows, n_cols = x.shape | ||
BLOCK_SIZE = triton.next_power_of_2(n_cols) | ||
|
||
y = torch.empty_like(x, device='cuda') | ||
g = torch.ones((1, n_cols), device='cuda') | ||
|
||
num_programs = n_rows | ||
grid = lambda meta: (num_programs, ) | ||
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE) | ||
|
||
return y | ||
|
||
|
||
def run_rmsnorm(M, N): | ||
torch.manual_seed(0) | ||
x = torch.randn(M, N, device='cuda') | ||
y_triton = rmsnorm(x) | ||
|
||
return y_triton | ||
|
||
|
||
@pytest.mark.parametrize('M, N', [ | ||
(1, 4), | ||
(2, 10), | ||
(8192, 4096), | ||
(4096, 8192), | ||
(1, 8192), | ||
(873, 1245), | ||
]) | ||
def test_rmsnorm(M, N): | ||
torch.manual_seed(0) | ||
x = torch.randn(M, N, device='cuda') | ||
y_triton = rmsnorm(x) | ||
|
||
rms_norm = torch.nn.RMSNorm(N, device='cuda') | ||
y_torch = rms_norm(x) | ||
|
||
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) | ||
|
||
|
||
#Benchmark | ||
arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} | ||
|
||
|
||
def torch_rmsnorm(x): | ||
M, N = x.shape | ||
rms_norm = torch.nn.RMSNorm(N, device='cuda') | ||
y_torch = rms_norm(x) | ||
|
||
return y_torch | ||
|
||
|
||
def run_benchmark(args): | ||
config = [] | ||
if (args.M_benchmark): | ||
val = args.M_start | ||
x_vals_list = [] | ||
while val <= args.M_end: | ||
x_vals_list.append(val) | ||
val *= args.M_step | ||
mn_args = {'N': args.N_start} | ||
plot_name = str("rmsnorm-performance_" + args.dtype + "_N" + str(args.N_start) + "_M" + str(args.M_start) + | ||
"-" + str(args.M_end) + "-" + str(args.M_step)) | ||
x_names = ['M'] | ||
else: | ||
x_vals_list = [i for i in range(args.N_start, args.N_end, args.N_step)] | ||
mn_args = {'M': args.M_start} | ||
x_names = ['N'] | ||
plot_name = str("rmsnorm-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) + | ||
"-" + str(args.N_end) + "-" + str(args.N_step)) | ||
|
||
dtype = arg_to_torch_dtype[args.dtype] | ||
|
||
print(plot_name) | ||
config.append( | ||
triton.testing.Benchmark( | ||
x_names=x_names, | ||
x_vals=x_vals_list, | ||
line_arg='provider', | ||
line_vals=['triton', 'torch'], | ||
line_names=["Triton", "Torch"], | ||
styles=[('blue', '-'), ('green', '-')], | ||
ylabel="GB/s", | ||
plot_name=plot_name, | ||
args=mn_args, | ||
)) | ||
|
||
@triton.testing.perf_report(config) | ||
def benchmark(M, N, provider): | ||
x = torch.randn(M, N, device='cuda', dtype=dtype) | ||
stream = torch.cuda.Stream() | ||
torch.cuda.set_stream(stream) | ||
if provider == 'torch': | ||
ms = triton.testing.do_bench(lambda: torch_rmsnorm(x)) | ||
if provider == 'triton': | ||
ms = triton.testing.do_bench(lambda: rmsnorm(x)) | ||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) | ||
return gbps(ms) | ||
|
||
benchmark.run(save_path=".", show_plots=True, print_data=True) | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
prog="Benchmark RMSNorm", | ||
allow_abbrev=False, | ||
) | ||
|
||
parser.add_argument('-M', "--M_start", default="1", type=int) | ||
parser.add_argument('-Ms', "--M_step", default="2", type=int) #This is multiplicative step | ||
parser.add_argument('-Me', "--M_end", default="512", type=int) | ||
parser.add_argument('-Mb', "--M_benchmark", default=False, type=bool) | ||
|
||
parser.add_argument('-N', "--N_start", default="8192", type=int) | ||
parser.add_argument('-Ns', "--N_step", default="1024", type=int) | ||
parser.add_argument('-Ne', "--N_end", default="32768", type=int) | ||
|
||
parser.add_argument('-d', "--dtype", default="fp16") | ||
parser.add_argument('-nb', "--no_benchmark", default=False, type=bool) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
if args.no_benchmark: | ||
run_rmsnorm(args.M_start, args.N_start) | ||
else: | ||
run_benchmark(args) | ||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(main()) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing wrong here. Just wondering if these configs are comprehensive enough to cover the typical use cases you have encountered in actual models
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question, I just came up with these on what I thought made sense. Other than that, no systematic way to come up with this. I was thinking may be I can add an argument to take in a custom config file. That way one doesn't need to touch the code in this file if some other configs need to be benchmarked.
Actually, I noticed a small bug for CUDA. There are repeated entries.