Commit

Implement custom sharding for flash_mha, to allow efficient multi-GPU computation when sharded across batch or head dimensions.
nshepperd committed Feb 20, 2024
1 parent 427ae9c commit 5a78a51
Showing 3 changed files with 414 additions and 73 deletions.
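For orientation, a minimal sketch of the usage this commit enables: calling flash_mha on inputs sharded across the batch dimension so each GPU runs its own slice of the kernel. This is illustrative only and not part of the diff; it assumes two or more local GPUs whose count divides the batch size.

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding
from flash_attn_jax import flash_mha

# Shard q/k/v ([n, l, h, d]) across the batch dimension, one slice per GPU.
sharding = PositionalSharding(jax.devices()).reshape(-1, 1, 1, 1)
q = jax.device_put(jnp.zeros([8, 4096, 4, 64], jnp.float16), sharding)
k = jax.device_put(jnp.zeros([8, 4096, 4, 64], jnp.float16), sharding)
v = jax.device_put(jnp.zeros([8, 4096, 4, 64], jnp.float16), sharding)

# With the custom partitioning rules added here, the compiled program runs
# one flash-attention kernel per device instead of gathering q/k/v onto a
# single GPU.
o = jax.jit(flash_mha)(q, k, v)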
83 changes: 25 additions & 58 deletions lame_test.py
@@ -19,7 +19,6 @@
# from flash_attn_jax.flash import flash_mha_fwd, flash_mha_bwd
from flash_attn_jax import flash_mha


if __name__ == '__main__':
import time
import numpy as np
@@ -51,66 +50,34 @@ def pretty(tensor):
def fwd(q,k,v):
return flash_mha(q,k,v)

# print(fwd.lower(q,k,v).as_text())

from jax.sharding import PositionalSharding
from einops import rearrange

sharding = PositionalSharding(jax.devices())
# sharding = PositionalSharding(jax.devices())
devices = jax.devices()
# devices = [*jax.devices(), *jax.devices(backend='cpu')]
n_device = len(devices)
sharding = PositionalSharding(devices).reshape(1,-1,1,1)#.replicate()


# from jax.experimental import mesh_utils
# from jax.sharding import PartitionSpec as P, Mesh
# from jax.sharding import NamedSharding
# devices = np.array(jax.devices()) #mesh_utils.create_device_mesh((1,))
# mesh = Mesh(devices, axis_names=('x',))
# sharding = NamedSharding(mesh, P(None,None,'x',None))

# print(mesh)

o_ref = fwd(q,k,v)

q = jax.device_put(q, sharding.reshape(2,1,1,1))
k = jax.device_put(k, sharding.reshape(2,1,1,1))
v = jax.device_put(v, sharding.reshape(2,1,1,1))
q = jax.device_put(q, sharding)
k = jax.device_put(k, sharding)
v = jax.device_put(v, sharding)
jax.debug.visualize_array_sharding(rearrange(q, 'n l h d -> n (l h d)'))
print(fwd.lower(q,k,v).compile().as_text())
exit()

# print('==== forward ====')
# q = jax.random.normal(jax.random.PRNGKey(0), [32, 4096, 4, 32]).astype(jnp.float16)
# k = jax.random.normal(jax.random.PRNGKey(1), [32, 4096, 4, 32]).astype(jnp.float16)
# v = jax.random.normal(jax.random.PRNGKey(2), [32, 4096, 4, 32]).astype(jnp.float16)

# @jax.jit
# def fwd(q,k,v):
# o = flash_mha(q,k,v)
# for _ in range(32):
# o = flash_mha(q,k,o)
# return o

# @jax.jit
# def fwd_jax(q,k,v):
# ro = pure_mha(q,k,v)
# for _ in range(32):
# ro = pure_mha(q,k,ro)
# return ro

# o = fwd(q,k,v) #, softmax_scale=float(np.sqrt(1/32)))[0]
# start = time.time()
# o = fwd(q,k,v) #, softmax_scale=float(np.sqrt(1/32)))[0]
# print('flash:', time.time() - start, 'seconds')
# ro = fwd_jax(q,k,v)
# start = time.time()
# ro = fwd_jax(q,k,v)
# print('jax:', time.time() - start, 'seconds')
# print(pretty(jnp.abs(o - ro)), jnp.mean(jnp.abs(ro)))

# @jax.jit
# @jax.grad
# def grad_pure(inputs):
# q,k,v = inputs
# return pure_mha(q,k,v).sum()

# @jax.jit
# @jax.grad
# def grad_flash(inputs):
# q,k,v = inputs
# return flash_mha(q,k,v).sum()

# print('==== backward ====')
# q = jax.random.normal(jax.random.PRNGKey(0), [1, 4, 2, 32]).astype(jnp.float16)
# k = jax.random.normal(jax.random.PRNGKey(1), [1, 4, 2, 32]).astype(jnp.float16)
# v = jax.random.normal(jax.random.PRNGKey(2), [1, 4, 2, 32]).astype(jnp.float16)
# dq, dk, dv = grad_flash((q,k,v))
# rdq, rdk, rdv = grad_pure((q,k,v))
# # print(rdq, jnp.mean(jnp.abs(rdq)))
# print('q', pretty(jnp.abs(dq - rdq)), jnp.mean(jnp.abs(rdq)))
# print('k', pretty(jnp.abs(dk - rdk)), jnp.mean(jnp.abs(rdk)))
# print('v', pretty(jnp.abs(dv - rdv)), jnp.mean(jnp.abs(rdv)))
o = fwd(q,k,v)
jax.debug.visualize_array_sharding(rearrange(o, 'n l h d -> n (l h d)'))
print((o - o_ref).std())
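For reference, a self-contained version of the NamedSharding path that the commented-out block above hints at, sharding the head dimension across a one-axis mesh. This is a sketch only, not part of the commit, and assumes the number of heads divides the local device count.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
# Shard the head dimension (axis 2 of [n, l, h, d]) across the 'x' axis.
sharding = NamedSharding(mesh, P(None, None, 'x', None))

q = jax.device_put(jnp.zeros([2, 1024, 8, 64], jnp.float16), sharding)
k = jax.device_put(jnp.zeros([2, 1024, 8, 64], jnp.float16), sharding)
v = jax.device_put(jnp.zeros([2, 1024, 8, 64], jnp.float16), sharding)

o = jax.jit(flash_mha)(q, k, v)  # each device attends over its own heads
jax.debug.visualize_array_sharding(o.reshape(2, -1))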
192 changes: 177 additions & 15 deletions src/flash_attn_jax/flash.py
@@ -1,4 +1,4 @@
from functools import partial
from functools import partial, wraps

import jax
import jax.numpy as jnp
@@ -10,13 +10,17 @@
from jax.interpreters.mlir import ir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
from jax.experimental.custom_partitioning import custom_partitioning

from einops import rearrange
import math

import flash_attn_jax.flash_api as flash_api

# ==== Register primitives ====

# We do this with two sets of primitives.
# These are the main ones used, and support any settings or sharding
_flash_mha_fwd_p = core.Primitive("flash_mha_fwd")
_flash_mha_fwd_p.multiple_results = True
_flash_mha_fwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_p))
@@ -25,8 +29,29 @@
_flash_mha_bwd_p.multiple_results = True
_flash_mha_bwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_p))


def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
# The low level 'cuda' primitives are only used for lowering to hlo,
# and require d to be padded to a multiple of 8, which we add during
# lowering of the main prims above.
_flash_mha_fwd_cuda_p = core.Primitive("flash_mha_fwd_cuda")
_flash_mha_fwd_cuda_p.multiple_results = True
_flash_mha_fwd_cuda_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_cuda_p))

_flash_mha_bwd_cuda_p = core.Primitive("flash_mha_bwd_cuda")
_flash_mha_bwd_cuda_p.multiple_results = True
_flash_mha_bwd_cuda_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_cuda_p))

# @partial(partial, partial)
# def trace(name, func):
# @wraps(func)
# def f(*args, **kwargs):
# print(name, args, kwargs)
# return func(*args, **kwargs)
# return f

# ==== Single shard low level frontend ===
# Adds padding before calling into the cuda primitive.

def _flash_mha_fwd_cuda_1(q, k, v, softmax_scale, is_causal, window_size):
d = q.shape[-1]
assert len(q.shape) == 4
assert d == k.shape[-1]
@@ -37,12 +62,12 @@ def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
q = jnp.pad(q, padding)
k = jnp.pad(k, padding)
v = jnp.pad(v, padding)
out, lse = _flash_mha_fwd_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
out, lse = _flash_mha_fwd_cuda_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
if d % 8 != 0:
out = out[..., :d]
return out, lse

def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
def _flash_mha_bwd_cuda_1(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
d = q.shape[-1]
assert len(q.shape) == 4
assert d == k.shape[-1]
@@ -55,11 +80,129 @@ def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size
v = jnp.pad(v, padding)
out = jnp.pad(out, padding)
dout = jnp.pad(dout, padding)
dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
dq, dk, dv = _flash_mha_bwd_cuda_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
if d % 8 != 0:
return dq[...,:d], dk[...,:d], dv[...,:d]
return dq, dk, dv

# ==== Sharding ====

_flash_mha_fwd_cuda = custom_partitioning(_flash_mha_fwd_cuda_1, static_argnums=(3,4,5))
_flash_mha_bwd_cuda = custom_partitioning(_flash_mha_bwd_cuda_1, static_argnums=(6,7,8))

from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding

# @trace("partition_fwd")
def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)

q_sharding = arg_shardings[0]
if isinstance(q_sharding, PositionalSharding):
if not is_causal and window_size == (-1,-1):
# We can handle Q that's sharded across the L dimension
# without replicating Q by executing it as a cross
# attention:
#
# q : n [L/devices] h d
# kv : n L h d
# -> o : n [L/devices] h d
#
# TODO: We could handle q sharded across L even with
# causal/local if we could communicate the slice offset
# (of q in kv) to the c++ driver. But it's unclear how to
# do that since the HLO has to be identical (SPMD).
q_sharding = q_sharding.replicate(3)
kv_sharding = q_sharding.replicate(1)
(n,l,h,d) = q_sharding.shape
result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1) # n h l
arg_shardings = q_sharding, kv_sharding, kv_sharding
else:
# We need to replicate d always.
q_sharding = q_sharding.replicate((1,3))
(n,l,h,d) = q_sharding.shape # l=1, d=1
result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1)
arg_shardings = q_sharding, q_sharding, q_sharding
elif isinstance(q_sharding, NamedSharding):
mesh = q_sharding.mesh
[n,l,h,d] = q_sharding.spec
if not is_causal and window_size == (-1,-1):
q_sharding = NamedSharding(mesh, P(n,l,h,None))
kv_sharding = NamedSharding(mesh, P(n,None,h,None))
lse_sharding = NamedSharding(mesh, P(n,h,l))
else:
q_sharding = NamedSharding(mesh, P(n,None,h,None))
kv_sharding = q_sharding
lse_sharding = NamedSharding(mesh, P(n,h,None))
result_shardings = (q_sharding, lse_sharding)
arg_shardings = (q_sharding, kv_sharding, kv_sharding)
def fwd(q,k,v):
return _flash_mha_fwd_cuda_1(q,k,v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
return mesh, fwd, result_shardings, arg_shardings

# @trace("infer_sharding_fwd")
def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
q_sharding = arg_shardings[0]
if isinstance(q_sharding, PositionalSharding):
[n,l,h,d] = q_sharding.shape
# breakpoint()
result_sharding = (q_sharding, # [n,l,h,d]
q_sharding.replicate(3).reshape(n,l,h).transpose((0,2,1)) # [n,h,l]
)
elif isinstance(q_sharding, NamedSharding):
[n,l,h,d] = q_sharding.spec
result_sharding = (q_sharding,
NamedSharding(q_sharding.mesh, P(n,h,l)))
return result_sharding

_flash_mha_fwd_cuda.def_partition(
infer_sharding_from_operands=infer_sharding_fwd,
partition=partition_fwd)

def infer_sharding_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
# args: dout, q, k, v, out, lse
# outs: dq, dk, dv
# I think generally we want the output sharding for dq,dk,dv to be the same as q,k,v?
arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
q_sharding = arg_shardings[1]
k_sharding = arg_shardings[2]
v_sharding = arg_shardings[3]
return q_sharding, k_sharding, v_sharding

def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)

do_sharding = arg_shardings[0]
q_sharding = arg_shardings[1]
k_sharding = arg_shardings[2]
v_sharding = arg_shardings[3]
o_sharding = arg_shardings[4]
lse_sharding = arg_shardings[5]
if isinstance(q_sharding, PositionalSharding):
do_sharding = q_sharding.replicate((1,3))
[n, l, h, d] = do_sharding.shape
lse_sharding = do_sharding.reshape(n,l,h).transpose(0,2,1) # n h l
result_shardings = (do_sharding,)*3
arg_shardings = (do_sharding,)*5 + (lse_sharding,)
elif isinstance(q_sharding, NamedSharding):
mesh = q_sharding.mesh
[n,l,h,d] = q_sharding.spec
do_sharding = NamedSharding(mesh, P(n,None,h,None))
lse_sharding = NamedSharding(mesh, P(n,h,None))
result_shardings = (do_sharding,)*3
def fwd(*args):
return _flash_mha_bwd_cuda_1(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
return mesh, fwd, result_shardings, arg_shardings

_flash_mha_bwd_cuda.def_partition(
infer_sharding_from_operands=infer_sharding_bwd,
partition=partition_bwd)

# ==== CUDA lowerings ====

# Register functions defined in gpu_ops as custom call target for GPUs
@@ -72,7 +215,6 @@ def row_major(shape):
return [row_major(shape) for shape in shapes]

def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is_causal=False, window_size=None):
# print(type(q), dir(q), q.type)
q_type = ir.RankedTensorType(q.type)
q_shape = q_type.shape
k_type = ir.RankedTensorType(k.type)
@@ -115,7 +257,7 @@ def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is
return out

mlir.register_lowering(
_flash_mha_fwd_p,
_flash_mha_fwd_cuda_p,
_flash_mha_fwd_cuda_lowering,
platform="gpu",
)
@@ -176,11 +318,25 @@ def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=Non
return out

mlir.register_lowering(
_flash_mha_bwd_p,
_flash_mha_bwd_cuda_p,
_flash_mha_bwd_cuda_lowering,
platform="gpu",
)

# ==== High level ops ====

mlir.register_lowering(
_flash_mha_fwd_p,
mlir.lower_fun(_flash_mha_fwd_cuda),
platform="gpu",
)

mlir.register_lowering(
_flash_mha_bwd_p,
mlir.lower_fun(_flash_mha_bwd_cuda),
platform="gpu",
)

# ==== Abstract evaluation rules ====

def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=None, window_size=None):
@@ -194,6 +350,7 @@ def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=No
ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
ShapedArray([n, h, l], jnp.float32)
)
_flash_mha_fwd_cuda_p.def_abstract_eval(_flash_mha_fwd_abstract)
_flash_mha_fwd_p.def_abstract_eval(_flash_mha_fwd_abstract)


@@ -212,6 +369,7 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, d_og=No
ShapedArray(k.shape, k_dtype, named_shape=k.named_shape),
ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
)
_flash_mha_bwd_cuda_p.def_abstract_eval(_flash_mha_bwd_abstract)
_flash_mha_bwd_p.def_abstract_eval(_flash_mha_bwd_abstract)

# ==== VMap rules ====
@@ -250,19 +408,19 @@ def custom_vjp(cls, nondiff_argnums=()):
f.defvjp(cls.fwd, cls.bwd)
return f

# Apparently we need nondiff_argnums so that softmax_scale doesn't get
# turned into a Tracer, which we can't use as a static parameter. It
# gets placed at the front of the argument list in bwd.
# Apparently we need nondiff_argnums so that config doesn't get turned
# into Tensors. They get placed at the front of the argument list in
# bwd.
@partial(custom_vjp, nondiff_argnums=(3,))
class _flash_mha_vjp:
def base(q,k,v,config):
return flash_mha_fwd(q,k,v, **config)[0]
return _flash_mha_fwd_p.bind(q,k,v, **config)[0]
def fwd(q,k,v,config):
out, lse = flash_mha_fwd(q,k,v, **config)
out, lse = _flash_mha_fwd_p.bind(q,k,v, **config)
return out, (q,k,v,out,lse)
def bwd(config, pack, dout):
(q,k,v,out,lse) = pack
dq, dk, dv = flash_mha_bwd(dout, q, k, v, out, lse, **config)
dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, **config)
return (dq,dk,dv)

# ==== Frontend ====
@@ -274,6 +432,10 @@ def flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1)):
provided (i.e. can't be a tensor or a tracer).
"""
assert len(q.shape) == 4
assert len(k.shape) == 4
assert len(v.shape) == 4

if softmax_scale is None:
softmax_scale = 1/math.sqrt(q.shape[-1])
assert type(softmax_scale) is float
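As a quick reference for the custom_partitioning mechanism the new sharding code is built on, here is a generic, simplified sketch (an element-wise toy op, not the flash-attention rules themselves) using the same def_partition(infer_sharding_from_operands=..., partition=...) form as above. The op name and callbacks are made up for illustration.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

@custom_partitioning
def double(x):
    return 2 * x

def infer_sharding(mesh, arg_shapes, result_shape):
    # The output is sharded the same way as the (only) operand.
    return arg_shapes[0].sharding

def partition(mesh, arg_shapes, result_shape):
    arg_shardings = tuple(a.sharding for a in arg_shapes)
    def lower(x):
        return 2 * x  # runs independently on each device's shard
    return mesh, lower, arg_shardings[0], arg_shardings

double.def_partition(infer_sharding_from_operands=infer_sharding,
                     partition=partition)

# Usage: shard a vector across all local devices; `double` is then applied
# shard-by-shard with no cross-device communication.
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
x = jax.device_put(jnp.arange(8.0 * len(jax.devices())),
                   NamedSharding(mesh, P('x')))
print(jax.jit(double)(x))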
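The comment in partition_fwd about running L-sharded queries as a cross attention can be seen in a plain-JAX reference (standard softmax attention standing in for the flash kernel, shapes [n, l, h, d]; illustration only, not part of the commit):

import jax
import jax.numpy as jnp

def mha_reference(q, k, v):
    # Plain attention over the full K/V. It works unchanged when q holds
    # only a slice of the query positions (non-causal, no local window),
    # because each query row depends on all of K/V but on no other queries.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    s = jnp.einsum('nlhd,nmhd->nhlm', q, k) * scale
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum('nhlm,nmhd->nlhd', p, v)

# So each device can run the kernel on its local [n, L/devices, h, d] slice
# of q against replicated k and v, which is the sharding partition_fwd picks
# when is_causal is False and window_size == (-1, -1).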
