-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
start chipping away at Triton version of PaLM, use causal numerically…
… stable softmax (no need for causal mask) + bias-less layernorm, modified from Phil Tillets layernorm tutorial, cite Triton
- Loading branch information
1 parent
01c5abd
commit 6834a28
Showing
7 changed files
with
476 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from palm_pytorch.triton.triton_palm import PaLM |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# taken from Phil Tillet's layernorm tutorial for Triton | ||
|
||
# Triton - https://triton-lang.org | ||
# Layernorm tutorial - https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py | ||
# modified to be bias-less | ||
|
||
import torch | ||
import triton | ||
import triton.language as tl | ||
|
||
@triton.jit | ||
def _layer_norm_fwd_fused(X, Y, W, M, V, stride, N, | ||
BLOCK_SIZE: tl.constexpr): | ||
|
||
row = tl.program_id(0) | ||
cols = tl.arange(0, BLOCK_SIZE) | ||
mask = cols < N | ||
|
||
X += row * stride | ||
Y += row * stride | ||
|
||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) | ||
|
||
mean = tl.sum(x, axis=0) / N | ||
|
||
xmean = tl.where(mask, x - mean, 0.) | ||
var = tl.sum(xmean * xmean, axis=0) / N | ||
rstd = 1 / tl.sqrt(var + 1e-5) | ||
xhat = xmean * rstd | ||
|
||
tl.store(M + row, mean) | ||
tl.store(V + row, rstd) | ||
|
||
w = tl.load(W + cols, mask=mask) | ||
y = xhat * w | ||
|
||
tl.store(Y + cols, y, mask=mask) | ||
|
||
|
||
@triton.jit | ||
def _layer_norm_bwd_dx_fused(DX, DY, DW, X, W, M, V, Lock, stride, N, | ||
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): | ||
|
||
row = tl.program_id(0) | ||
cols = tl.arange(0, BLOCK_SIZE_N) | ||
mask = cols < N | ||
|
||
X += row * stride | ||
DY += row * stride | ||
DX += row * stride | ||
|
||
lock_id = row % GROUP_SIZE_M | ||
Lock += lock_id | ||
Count = Lock + GROUP_SIZE_M | ||
DW = DW + lock_id * N + cols | ||
|
||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) | ||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) | ||
w = tl.load(W + cols, mask=mask).to(tl.float32) | ||
mean = tl.load(M + row) | ||
rstd = tl.load(V + row) | ||
|
||
xhat = (x - mean) * rstd | ||
wdy = w * dy | ||
xhat = tl.where(mask, xhat, 0.) | ||
wdy = tl.where(mask, wdy, 0.) | ||
mean1 = tl.sum(xhat * wdy, axis=0) / N | ||
mean2 = tl.sum(wdy, axis=0) / N | ||
dx = (wdy - (xhat * mean1 + mean2)) * rstd | ||
|
||
tl.store(DX + cols, dx, mask=mask) | ||
|
||
partial_dw = (dy * xhat).to(w.dtype) | ||
|
||
while tl.atomic_cas(Lock, 0, 1) == 1: | ||
pass | ||
count = tl.load(Count) | ||
|
||
if count == 0: | ||
tl.atomic_xchg(Count, 1) | ||
else: | ||
partial_dw += tl.load(DW, mask=mask) | ||
|
||
tl.store(DW, partial_dw, mask=mask) | ||
|
||
tl.atomic_xchg(Lock, 0) | ||
|
||
@triton.jit | ||
def _layer_norm_bwd_dw(DW, FINAL_DW, M, N, | ||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): | ||
pid = tl.program_id(0) | ||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) | ||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) | ||
|
||
for i in range(0, M, BLOCK_SIZE_M): | ||
rows = i + tl.arange(0, BLOCK_SIZE_M) | ||
mask = (rows[:, None] < M) & (cols[None, :] < N) | ||
offs = rows[:, None] * N + cols[None, :] | ||
dw += tl.load(DW + offs, mask=mask, other=0.) | ||
|
||
sum_dw = tl.sum(dw, axis=0) | ||
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) | ||
|
||
|
||
class LayerNorm(torch.autograd.Function): | ||
|
||
@staticmethod | ||
def forward(ctx, x, normalized_shape, weight): | ||
y = torch.empty_like(x) | ||
|
||
x_arg = x.reshape(-1, x.shape[-1]) | ||
M, N = x_arg.shape | ||
mean = torch.empty((M, ), dtype=torch.float32, device='cuda') | ||
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda') | ||
|
||
MAX_FUSED_SIZE = 65536 // x.element_size() | ||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) | ||
if N > BLOCK_SIZE: | ||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") | ||
|
||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8) | ||
|
||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, mean, rstd, | ||
x_arg.stride(0), N, | ||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) | ||
ctx.save_for_backward(x, weight, mean, rstd) | ||
ctx.BLOCK_SIZE = BLOCK_SIZE | ||
ctx.num_warps = num_warps | ||
return y | ||
|
||
@staticmethod | ||
def backward(ctx, dy): | ||
x, w, m, v = ctx.saved_tensors | ||
|
||
N = w.shape[0] | ||
GROUP_SIZE_M = 64 | ||
if N <= 8192: GROUP_SIZE_M = 96 | ||
if N <= 4096: GROUP_SIZE_M = 128 | ||
if N <= 1024: GROUP_SIZE_M = 256 | ||
|
||
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') | ||
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) | ||
|
||
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) | ||
dx = torch.empty_like(dy) | ||
|
||
x_arg = x.reshape(-1, x.shape[-1]) | ||
M, N = x_arg.shape | ||
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, x, w, m, v, locks, | ||
x_arg.stride(0), N, | ||
BLOCK_SIZE_N=ctx.BLOCK_SIZE, | ||
GROUP_SIZE_M=GROUP_SIZE_M, | ||
num_warps=ctx.num_warps) | ||
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] | ||
|
||
_layer_norm_bwd_dw[grid](_dw, dw, GROUP_SIZE_M, N, | ||
BLOCK_SIZE_M=32, | ||
BLOCK_SIZE_N=128) | ||
return dx, None, dw, None | ||
|
||
layernorm_without_bias = LayerNorm.apply |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import torch | ||
from torch import autograd | ||
import torch.nn.functional as F | ||
|
||
import triton | ||
import triton.language as tl | ||
from triton_transformer.utils import calc_num_warps | ||
|
||
@triton.jit | ||
def softmax_kernel_forward( | ||
output_ptr, | ||
input_ptr, | ||
input_row_stride, | ||
output_row_stride, | ||
n_cols, | ||
BLOCK_SIZE: tl.constexpr | ||
): | ||
row_idx = tl.program_id(0) | ||
|
||
row_start_ptr = input_ptr + row_idx * input_row_stride | ||
|
||
col_offsets = tl.arange(0, BLOCK_SIZE) | ||
input_ptrs = row_start_ptr + col_offsets | ||
|
||
mask = col_offsets < n_cols | ||
|
||
row = tl.load(input_ptrs, mask = mask, other = -float('inf')) | ||
|
||
causal_mask = col_offsets > (row_idx % n_cols) | ||
row = row + tl.where(causal_mask, -float('inf'), 0.) | ||
|
||
row_minus_max = row - tl.max(row, axis=0) | ||
|
||
numerator = tl.exp(row_minus_max) | ||
denominator = tl.sum(numerator, axis=0) | ||
softmax_output = numerator / denominator | ||
|
||
output_row_start_ptr = output_ptr + row_idx * output_row_stride | ||
output_ptrs = output_row_start_ptr + col_offsets | ||
tl.store(output_ptrs, softmax_output, mask = mask) | ||
|
||
@triton.jit | ||
def softmax_kernel_backward( | ||
output_ptr, | ||
input_ptr, | ||
grad_ptr, | ||
grad_row_stride, | ||
input_row_stride, | ||
output_row_stride, | ||
n_cols, | ||
BLOCK_SIZE: tl.constexpr | ||
): | ||
row_idx = tl.program_id(0) | ||
|
||
row_start_ptr = input_ptr + row_idx * input_row_stride | ||
grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride | ||
|
||
col_offsets = tl.arange(0, BLOCK_SIZE) | ||
input_ptrs = row_start_ptr + col_offsets | ||
grad_ptrs = grad_row_start_ptr + col_offsets | ||
|
||
mask = col_offsets < n_cols | ||
|
||
probs_row = tl.load(input_ptrs, mask = mask, other = 0.) | ||
grad_row = tl.load(grad_ptrs, mask = mask, other = 0.) | ||
|
||
dxhat = probs_row * grad_row | ||
softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0) | ||
|
||
output_row_start_ptr = output_ptr + row_idx * output_row_stride | ||
output_ptrs = output_row_start_ptr + col_offsets | ||
tl.store(output_ptrs, softmax_grad_output, mask = mask) | ||
|
||
class _softmax(autograd.Function): | ||
@classmethod | ||
def forward(self, ctx, x): | ||
shape = x.shape | ||
x = x.view(-1, shape[-1]) | ||
n_rows, n_cols = x.shape | ||
|
||
BLOCK_SIZE = triton.next_power_of_2(n_cols) | ||
num_warps = calc_num_warps(BLOCK_SIZE) | ||
|
||
y = torch.empty_like(x) | ||
|
||
softmax_kernel_forward[(n_rows,)]( | ||
y, | ||
x, | ||
x.stride(0), | ||
y.stride(0), | ||
n_cols, | ||
num_warps = num_warps, | ||
BLOCK_SIZE = BLOCK_SIZE, | ||
) | ||
|
||
if x.requires_grad: | ||
ctx.save_for_backward(y) | ||
return y.view(*shape) | ||
|
||
@classmethod | ||
def backward(self, ctx, grad_probs): | ||
shape = grad_probs.shape | ||
probs, = ctx.saved_tensors | ||
|
||
grad_probs = grad_probs.view(-1, grad_probs.shape[-1]) | ||
n_rows, n_cols = grad_probs.shape | ||
|
||
BLOCK_SIZE = triton.next_power_of_2(n_cols) | ||
num_warps = calc_num_warps(BLOCK_SIZE) | ||
|
||
dx = torch.empty_like(probs) | ||
|
||
softmax_kernel_backward[(n_rows,)]( | ||
dx, | ||
probs, | ||
grad_probs, | ||
grad_probs.stride(0), | ||
probs.stride(0), | ||
dx.stride(0), | ||
n_cols, | ||
num_warps = num_warps, | ||
BLOCK_SIZE = BLOCK_SIZE | ||
) | ||
|
||
return dx.view(*shape), None | ||
|
||
causal_softmax = _softmax.apply |
Oops, something went wrong.