
Commit

start chipping away at Triton version of PaLM, use causal numerically stable softmax (no need for causal mask) + bias-less layernorm, modified from Phil Tillet's layernorm tutorial, cite Triton
lucidrains committed Apr 5, 2022
1 parent 01c5abd commit 6834a28
Showing 7 changed files with 476 additions and 9 deletions.
19 changes: 12 additions & 7 deletions README.md
@@ -49,16 +49,21 @@ palm = PaLM(
$ python train.py
```

## Todo

- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer

## Citations

```bibtex
@article{chowdhery2022PaLM,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Chowdhery, Aakanksha et al.},
    year    = {2022}
}
```

```bibtex
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. T. Kung and David D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
```
1 change: 1 addition & 0 deletions palm_pytorch/triton/__init__.py
@@ -0,0 +1 @@
from palm_pytorch.triton.triton_palm import PaLM
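
For orientation, here is a minimal usage sketch of the Triton variant exported by this `__init__.py`. The constructor keyword arguments are assumed to mirror the plain PyTorch `PaLM` shown in the README (they are not part of this diff), and a CUDA device is required since the forward pass runs Triton kernels:

```python
import torch
from palm_pytorch.triton import PaLM

# hyperparameters below are assumptions, copied from the non-Triton README example
palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    heads = 8,
    dim_head = 64
).cuda()

tokens = torch.randint(0, 20000, (1, 2048)).cuda()
logits = palm(tokens)  # (1, 2048, 20000)
```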
161 changes: 161 additions & 0 deletions palm_pytorch/triton/layernorm.py
@@ -0,0 +1,161 @@
# taken from Phil Tillet's layernorm tutorial for Triton

# Triton - https://triton-lang.org
# Layernorm tutorial - https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py
# modified to be bias-less

import torch
import triton
import triton.language as tl

@triton.jit
def _layer_norm_fwd_fused(X, Y, W, M, V, stride, N,
                          BLOCK_SIZE: tl.constexpr):
    # each program normalizes one row of the (M, N) input
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    X += row * stride
    Y += row * stride

    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)

    # mean and variance computed in float32 for numerical stability
    mean = tl.sum(x, axis=0) / N

    xmean = tl.where(mask, x - mean, 0.)
    var = tl.sum(xmean * xmean, axis=0) / N
    rstd = 1 / tl.sqrt(var + 1e-5)
    xhat = xmean * rstd

    # stash per-row mean and reciprocal std for the backward pass
    tl.store(M + row, mean)
    tl.store(V + row, rstd)

    # scale by the weight only - no bias term
    w = tl.load(W + cols, mask=mask)
    y = xhat * w

    tl.store(Y + cols, y, mask=mask)


@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, X, W, M, V, Lock, stride, N,
                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # each program handles one row: it writes dx for that row and
    # accumulates a partial dw into one of GROUP_SIZE_M shared buffers
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    X += row * stride
    DY += row * stride
    DX += row * stride

    # rows are assigned round-robin to GROUP_SIZE_M accumulation buffers,
    # each guarded by its own lock
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols

    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    mean = tl.load(M + row)
    rstd = tl.load(V + row)

    # gradient with respect to the input
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * mean1 + mean2)) * rstd

    tl.store(DX + cols, dx, mask=mask)

    # partial gradient with respect to the weight
    partial_dw = (dy * xhat).to(w.dtype)

    # spin on the lock guarding this group's dw buffer
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)

    # first writer initializes the buffer, later writers accumulate into it
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)

    tl.store(DW, partial_dw, mask=mask)

    # release the lock
    tl.atomic_xchg(Lock, 0)

@triton.jit
def _layer_norm_bwd_dw(DW, FINAL_DW, M, N,
                       BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # reduce the (GROUP_SIZE_M, N) partial dw buffer to the final (N,) dw
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)

    sum_dw = tl.sum(dw, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)


class LayerNorm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, normalized_shape, weight):
        y = torch.empty_like(x)

        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        mean = torch.empty((M,), dtype=torch.float32, device='cuda')
        rstd = torch.empty((M,), dtype=torch.float32, device='cuda')

        # the entire feature dimension must fit into a single block
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

        # launch one program per row
        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, mean, rstd,
                                    x_arg.stride(0), N,
                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, m, v = ctx.saved_tensors

        # heuristic for the number of parallel dw accumulation buffers
        N = w.shape[0]
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256

        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
        _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)

        dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)

        # first kernel: dx for every row + partial dw per group
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, x, w, m, v, locks,
                                       x_arg.stride(0), N,
                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                       GROUP_SIZE_M=GROUP_SIZE_M,
                                       num_warps=ctx.num_warps)
        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]

        # second kernel: reduce the partial dw buffers into the final dw
        _layer_norm_bwd_dw[grid](_dw, dw, GROUP_SIZE_M, N,
                                 BLOCK_SIZE_M=32,
                                 BLOCK_SIZE_N=128)
        return dx, None, dw, None

layernorm_without_bias = LayerNorm.apply
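
The kernels above follow the tutorial's fused layer norm with the bias path stripped out. As a quick sanity check, output and gradients can be compared against `F.layer_norm` called without a bias. This is only a sketch, assuming a CUDA device with Triton installed, and it uses loose tolerances:

```python
import torch
import torch.nn.functional as F
from palm_pytorch.triton.layernorm import layernorm_without_bias

dim = 512
x = torch.randn(4, 1024, dim, device = 'cuda', requires_grad = True)
w = torch.randn(dim, device = 'cuda', requires_grad = True)

# forward: Triton kernel vs. reference bias-less layernorm
y_triton = layernorm_without_bias(x, (dim,), w)
y_ref = F.layer_norm(x, (dim,), weight = w)
assert torch.allclose(y_triton, y_ref, atol = 1e-4)

# backward: gradients with respect to input and weight should also agree
dy = torch.randn_like(y_triton)
y_triton.backward(dy)
dx_triton, dw_triton = x.grad.clone(), w.grad.clone()
x.grad, w.grad = None, None

y_ref.backward(dy)
assert torch.allclose(dx_triton, x.grad, atol = 1e-4)
assert torch.allclose(dw_triton, w.grad, atol = 1e-2)
```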
127 changes: 127 additions & 0 deletions palm_pytorch/triton/softmax.py
@@ -0,0 +1,127 @@
import torch
from torch import autograd
import torch.nn.functional as F

import triton
import triton.language as tl
from triton_transformer.utils import calc_num_warps

@triton.jit
def softmax_kernel_forward(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # each program computes the softmax of one row of the attention matrix
    row_idx = tl.program_id(0)

    row_start_ptr = input_ptr + row_idx * input_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    mask = col_offsets < n_cols

    row = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # fold the causal mask into the kernel: columns past the query position
    # (row index within the sequence) are set to -inf before the max / exp
    causal_mask = col_offsets > (row_idx % n_cols)
    row = row + tl.where(causal_mask, -float('inf'), 0.)

    # numerically stable softmax: subtract the row max before exponentiating
    row_minus_max = row - tl.max(row, axis=0)

    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask = mask)

@triton.jit
def softmax_kernel_backward(
    output_ptr,
    input_ptr,
    grad_ptr,
    grad_row_stride,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)

    row_start_ptr = input_ptr + row_idx * input_row_stride
    grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    grad_ptrs = grad_row_start_ptr + col_offsets

    mask = col_offsets < n_cols

    probs_row = tl.load(input_ptrs, mask = mask, other = 0.)
    grad_row = tl.load(grad_ptrs, mask = mask, other = 0.)

    # softmax backward: dx = p * (dy - sum(p * dy))
    dxhat = probs_row * grad_row
    softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0)

    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_grad_output, mask = mask)

class _softmax(autograd.Function):
    @classmethod
    def forward(self, ctx, x):
        shape = x.shape
        x = x.view(-1, shape[-1])
        n_rows, n_cols = x.shape

        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        y = torch.empty_like(x)

        # launch one program per row
        softmax_kernel_forward[(n_rows,)](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE,
        )

        # only the probabilities are needed for the backward pass
        if x.requires_grad:
            ctx.save_for_backward(y)
        return y.view(*shape)

    @classmethod
    def backward(self, ctx, grad_probs):
        shape = grad_probs.shape
        probs, = ctx.saved_tensors

        grad_probs = grad_probs.view(-1, grad_probs.shape[-1])
        n_rows, n_cols = grad_probs.shape

        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        dx = torch.empty_like(probs)

        softmax_kernel_backward[(n_rows,)](
            dx,
            probs,
            grad_probs,
            grad_probs.stride(0),
            probs.stride(0),
            dx.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE
        )

        return dx.view(*shape), None

causal_softmax = _softmax.apply
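
Because the causal mask is folded into the forward kernel (columns where `col_offsets > row_idx % n_cols` are sent to `-inf` before the max and exp), no separate mask tensor is needed. Below is a small sketch comparing against an explicitly masked softmax for a square attention matrix; it assumes a CUDA device and that `triton-transformer` is installed for the `calc_num_warps` helper imported above:

```python
import torch
from palm_pytorch.triton.softmax import causal_softmax

b, h, n = 2, 8, 1024
sim = torch.randn(b, h, n, n, device = 'cuda')

# Triton kernel: causal masking happens inside the softmax itself
out_triton = causal_softmax(sim)

# reference: explicit upper-triangular mask, then a regular softmax
mask = torch.ones(n, n, device = 'cuda', dtype = torch.bool).triu(1)
out_ref = sim.masked_fill(mask, float('-inf')).softmax(dim = -1)

assert torch.allclose(out_triton, out_ref, atol = 1e-5)
```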

