update sparse computation, add initial version of BlockCSR and BlockELL #81

Merged 1 commit on Dec 14, 2024
7 changes: 4 additions & 3 deletions brainunit/sparse/__init__.py
@@ -13,11 +13,12 @@
# limitations under the License.
# ==============================================================================


from .coo import COO, coo_todense, coo_fromdense
from .csr import CSR, CSC, csr_todense, csr_fromdense, csc_fromdense, csc_todense
from ._coo import COO, coo_todense, coo_fromdense
from ._csr import CSR, CSC, csr_todense, csr_fromdense, csc_fromdense, csc_todense
from .._sparse_base import SparseMatrix

__all__ = [
"SparseMatrix",
"CSR", "CSC",
"csr_todense", "csr_fromdense",
"csc_todense", "csc_fromdense",
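Since this hunk only makes the implementation modules private (`coo` → `_coo`, `csr` → `_csr`), the package-level API is unchanged. A minimal import sketch using only the names re-exported above:

```python
# Package-level names still resolve as before; only the backing modules were renamed.
from brainunit.sparse import COO, CSR, CSC, SparseMatrix
from brainunit.sparse import csr_fromdense, csr_todense, csc_fromdense, csc_todense
```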
240 changes: 240 additions & 0 deletions brainunit/sparse/_block_csr.py
@@ -0,0 +1,240 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import functools
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

import brainunit as u
from brainunit._base import Quantity
from brainunit._sparse_base import SparseMatrix

__all__ = [
'BlockCSR',
]


@jax.tree_util.register_pytree_node_class
class BlockCSR(SparseMatrix):
"""
Unit-aware Block-CSR sparse matrix.
"""
data: jax.Array | Quantity # float32[n_blocks, *block_size]
indptr: jax.Array # int32[n_block_rows + 1]
indices: jax.Array # int32[n_blocks]
shape: tuple[int, int] # (n_block_rows * block_size[0], n_block_cols * block_size[1])

ndim: int = property(lambda self: len(self.shape))
num_blocks = property(lambda self: self.data.shape[0])
block_size = property(lambda self: self.data.shape[1:])
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape: Tuple[int, int]):
blocks, indptr, indices = args
self.data = blocks
self.indptr = indptr
self.indices = indices
super().__init__(args, shape=shape)

def tree_flatten(self):
return (self.data,), (self.indptr, self.indices, self.shape,)

@classmethod
def tree_unflatten(cls, data, xs):
blocks, = xs
indptr, indices, shape = data
return BlockCSR((blocks, indptr, indices), shape=shape)

def _validate(self):
_nblocks, n, m = self.data.shape
nrows = self.indptr.shape[0] - 1
assert self.indices.shape[0] == _nblocks
assert len(self.shape) == 2
assert self.shape[0] == n * nrows
assert self.shape[1] % m == 0

@jax.jit
def todense(self) -> jax.Array:
self._validate()
return _sdd_todense(self)

@classmethod
def fromdense(cls, dense: jax.Array, *, block_size) -> 'BlockCSR':
raise NotImplementedError

def __matmul__(self, other) -> jax.Array:
self._validate()
return sdd_matmul(self, other)


@jax.jit
def _sdd_todense(mat: BlockCSR) -> jax.Array:
_, n, m = mat.data.shape
nrows = mat.shape[0] // n
unit = u.get_unit(mat.data)
blocks = u.get_mantissa(mat.data)

def i_body(i_row, out): # each row
def j_body(x): # each block in the row
i_block, val = x
i_col = mat.indices[i_block]
val = jax.lax.dynamic_update_slice(val, blocks[i_block], (i_row * n, i_col * m))
return i_block + 1, val

return jax.lax.while_loop(
lambda x: x[0] < mat.indptr[i_row + 1],
j_body,
(mat.indptr[i_row], out)
)[1]

dense = jax.lax.fori_loop(0, nrows, i_body, jnp.zeros(mat.shape, mat.dtype))
return u.maybe_decimal(u.Quantity(dense, unit=unit))


def _check_shape_consistency(x, y):
assert isinstance(y, jax.Array), f"Only jax.Array is supported, but got {type(y)}"
assert x.ndim == y.ndim == 2
assert x.shape[1] == y.shape[0], f"Dimension mismatch: {x.shape} @ {y.shape}"


def _sdd_kernel(
x_ref, # [n_blocks, bm, bn]
indices_ref, # [n_blocks]
indptr_ref, # [n_rows + 1]
y_ref, # [n, k]
o_ref, # [m, k]
*,
bm: int,
bn: int,
bk: int,
):
i_m = pl.program_id(axis=0)
i_k = pl.program_id(axis=1)
i_start = indptr_ref[i_m]
i_end = indptr_ref[i_m + 1]

def body(x):
val, i_block = x
i_x_col = indices_ref[i_block]
block = pl.load(x_ref, (i_block, pl.dslice(None), pl.dslice(None))) # [bm, bn]
chunk = pl.load(y_ref, (pl.dslice(i_x_col * bn, bn), pl.dslice(i_k * bk, bk))) # [bn, bk]
return val + jnp.dot(block, chunk).astype(o_ref.dtype), i_block + 1

acc = jax.lax.while_loop(
lambda x: x[1] < i_end,
body,
(jnp.zeros([bm, bk], dtype=o_ref.dtype), i_start)
)[0]
pl.store(o_ref, (pl.dslice(bm * i_m, bm), pl.dslice(bk * i_k, bk)), acc) # [bm, bk]


@functools.partial(jax.jit, static_argnames=["debug", 'interpret', 'block_size'])
def sdd_matmul(
mat1: BlockCSR,
mat2: jax.Array,
*,
debug: bool = False,
interpret: bool = False,
block_size: int = 256,
) -> jax.Array:
_check_shape_consistency(mat1, mat2)

# shape and dtype
m, n, k = mat1.shape[0], mat1.shape[1], mat2.shape[1]
_, bm, bn = mat1.data.shape
dtype = jnp.result_type(mat1.dtype, mat2.dtype)

# kernel
fn = pl.pallas_call(
functools.partial(_sdd_kernel, bm=bm, bn=bn, bk=block_size),
out_shape=jax.ShapeDtypeStruct(shape=(m, k), dtype=dtype),
grid=(pl.cdiv(m, bm), pl.cdiv(k, block_size)),
debug=debug,
interpret=interpret
)

# call
unita = u.get_unit(mat1.data)
unitb = u.get_unit(mat2)
blocks = u.get_mantissa(mat1.data)
r = fn(blocks, mat1.indices, mat1.indptr, u.get_mantissa(mat2))
return u.maybe_decimal(u.Quantity(r, unit=unita * unitb))


@jax.jit
def native_sdd_matmul(
mat1: BlockCSR,
mat2: jax.Array,
):
_check_shape_consistency(mat1, mat2)

dtype = jnp.result_type(mat1.dtype, mat2.dtype)
_, n, m = mat1.data.shape

nrows = mat1.shape[0] // n

def i_body(i): # each row
def k_body(x):
i_block, val = x
i_col = mat1.indices[i_block]
chunk = jax.lax.dynamic_slice(mat2, [i_col * m, 0], (m, mat2.shape[1])) # [m, mat2.shape[1]]
block = blocks[i_block]
return i_block + 1, val + block.dot(chunk)

acc = jax.lax.while_loop(
lambda x: x[0] < mat1.indptr[i + 1],
k_body,
(mat1.indptr[i], jnp.zeros((n, mat2.shape[1]), dtype=jnp.float32))
)[1]
return acc.astype(dtype)

unita = u.get_unit(mat1.data)
unitb = u.get_unit(mat2)
blocks = u.get_mantissa(mat1.data)
mat2 = u.get_mantissa(mat2)

out = jax.vmap(i_body)(jnp.arange(nrows)).reshape((mat1.shape[0], mat2.shape[1]))
return u.maybe_decimal(u.Quantity(out, unit=unita * unitb))


def sample_sparse_matrix(
m, n, bm, bn, *,
sparse_prob=0.2,
dtype=jnp.float32
) -> BlockCSR:
num_rows = m // bm # number of block rows
num_cols = n // bn # number of block columns
blocks_per_row = np.random.binomial(num_cols, sparse_prob,
size=[num_rows]) # [n_rows], number of nonzero blocks in each row
num_blocks = blocks_per_row.sum()
blocks = np.random.randn(num_blocks, bm, bn).astype(dtype) # [n_blocks, bm, bn], block values

# [n_rows + 1], row pointers
indptr = np.zeros(num_rows + 1, dtype=np.int32)
indptr[1:] = np.cumsum(blocks_per_row)

# [n_blocks], column index of each nonzero block
indices = []
for i in range(num_rows):
indices.extend(np.random.choice(num_cols, blocks_per_row[i], replace=False))
indices = jnp.array(indices)

return BlockCSR((jnp.asarray(blocks), jnp.asarray(indptr), jnp.asarray(indices)), shape=(m, n))
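To make the block storage layout concrete, here is a hand-built toy instance; the values and sizes are illustrative only, and it exercises the constructor, `_validate`, and `todense` added above:

```python
import jax.numpy as jnp
from brainunit.sparse._block_csr import BlockCSR

# A 4x4 matrix stored as 2x2 blocks: block-row 0 holds blocks at block-columns 0 and 1,
# block-row 1 holds a single block at block-column 1.
blocks = jnp.arange(12, dtype=jnp.float32).reshape(3, 2, 2)  # [n_blocks, bm, bn]
indptr = jnp.array([0, 2, 3], dtype=jnp.int32)               # [n_block_rows + 1]
indices = jnp.array([0, 1, 1], dtype=jnp.int32)              # [n_blocks], block-column of each block

mat = BlockCSR((blocks, indptr, indices), shape=(4, 4))
print(mat.block_size)   # (2, 2)
print(mat.num_blocks)   # 3
print(mat.todense())    # dense 4x4 array, zero where no block is stored
```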
99 changes: 99 additions & 0 deletions brainunit/sparse/_block_csr_benchmark.py
@@ -0,0 +1,99 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import functools
import timeit

import brainstate as bst
import jax
import jax.numpy as jnp

from brainunit.sparse._block_csr import sample_sparse_matrix, sdd_matmul, native_sdd_matmul


def main(dtype=jnp.float16, sparse_prob=0.2):
bst.random.seed(1234)
# data
m, k, n = 4096, 4096, 4096
bm, bn, bk = 32, 32, 256
print(f"Matrix Shape: {m} x {k} x {n}, dtype: {dtype}, sparse_prob: {sparse_prob}")

x = sample_sparse_matrix(m, k, bm, bn, sparse_prob=sparse_prob, dtype=dtype)
x_dense = x.todense()
y = bst.random.randn(k, n, dtype=dtype)

# operations
interpret = jax.devices()[0].platform == "cpu"
# sdd_matmul(x, y, debug=False, block_size=bk, interpret=interpret).block_until_ready()
native_matmul = jax.jit(native_sdd_matmul)
pl_matmul = jax.jit(functools.partial(sdd_matmul, block_size=bk, interpret=interpret))
dense_matmul = jax.jit(jnp.matmul)
native_grad = jax.jit(jax.grad(native_sdd_matmul, argnums=(0, 1)))
pl_grad = jax.jit(jax.grad(functools.partial(sdd_matmul, block_size=bk, interpret=interpret), argnums=(0, 1)))
dense_grad = jax.jit(jax.grad(jnp.matmul, argnums=(0, 1)))

# compilation
out_pl = pl_matmul(x, y)
out_hlo = native_matmul(x, y)
out_ref = dense_matmul(x_dense, y)
# out_pl = pl_grad(x, y)
# out_hlo = native_grad(x, y)
# out_ref = dense_grad(x_dense, y)

# print(jnp.max(jnp.abs(out_pl - out_ref)))
# print(jnp.max(jnp.abs(out_pl - out_ref) / jnp.abs(out_pl)))
# np.testing.assert_allclose(out_pl, out_ref, atol=0.04, rtol=0.04)
# np.testing.assert_allclose(out_hlo, out_ref, atol=0.04, rtol=0.04)

n_trial1, n_trial2 = (10, 2) if interpret else (1000, 20)
duration = timeit.timeit(lambda: dense_matmul(x_dense, y).block_until_ready(), number=n_trial1)
s1_forward = duration / n_trial1 * 1000
print(f"Dense Matmul, forward: {s1_forward:.2f}ms")
# duration = timeit.timeit(lambda: jax.block_until_ready(dense_grad(x, y)), number=n_trial1)
# s1_backward = duration / n_trial1 * 1000
# print(f"Dense Matmul, backward: {s1_backward:.2f}ms")

duration = timeit.timeit(lambda: pl_matmul(x, y).block_until_ready(), number=n_trial1)
s2_forward = duration / n_trial1 * 1000
print(f"Pallas Blocksparse Matmul, forward: {s2_forward:.2f}ms")
# duration = timeit.timeit(lambda: jax.block_until_ready(pl_grad(x, y)), number=n_trial1)
# s2_backward = duration / n_trial1 * 1000
# print(f"Pallas Blocksparse Matmul, backward: {s2_backward:.2f}ms")

duration = timeit.timeit(lambda: native_matmul(x, y).block_until_ready(), number=n_trial2)
s3_forward = duration / n_trial2 * 1000
print(f"HLO Blocksparse Matmul, forward: {s3_forward:.2f}ms")
# duration = timeit.timeit(lambda: jax.block_until_ready(native_grad(x, y)), number=n_trial2)
# s3_backward = duration / n_trial2 * 1000
# print(f"HLO Blocksparse Matmul, backward: {s3_backward:.2f}ms")

print(f"Forward speedup: {s1_forward / s2_forward:.2f}x (Dense vs. Pallas), "
f"{s3_forward / s2_forward:.2f}x (HLO vs. Pallas)")
# print(f"Backward speedup: {s1_backward / s2_backward:.2f}x (Dense vs. Pallas), "
# f"{s3_backward / s2_backward:.2f}x (HLO vs. Pallas)")
print()


if __name__ == "__main__":
main(jnp.float32, 0.3)
main(jnp.float32, 0.2)
main(jnp.float32, 0.1)
main(jnp.float32, 0.05)
main(jnp.float16, 0.3)
main(jnp.float16, 0.2)
main(jnp.float16, 0.1)
main(jnp.float16, 0.05)
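The `assert_allclose` checks above are left commented out; a standalone sketch of the same comparison, with the tolerances copied from the commented lines and smaller sizes assumed so it also finishes quickly in interpret mode on CPU:

```python
import jax
import jax.numpy as jnp
import numpy as np

from brainunit.sparse._block_csr import sample_sparse_matrix, sdd_matmul, native_sdd_matmul

interpret = jax.devices()[0].platform == "cpu"  # run the Pallas kernel in interpret mode on CPU
x = sample_sparse_matrix(512, 512, 32, 32, sparse_prob=0.2, dtype=jnp.float32)
y = jnp.asarray(np.random.randn(512, 128), dtype=jnp.float32)

out_ref = x.todense() @ y                                       # dense reference
out_pl = sdd_matmul(x, y, block_size=128, interpret=interpret)  # Pallas block-sparse kernel
out_hlo = native_sdd_matmul(x, y)                               # pure-JAX (HLO) variant

np.testing.assert_allclose(out_pl, out_ref, atol=0.04, rtol=0.04)
np.testing.assert_allclose(out_hlo, out_ref, atol=0.04, rtol=0.04)
```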