-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic CSC and CSR sparse matrix support
- Loading branch information
1 parent
cdef9cc
commit f6a491a
Showing
3 changed files
with
284 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import numpy as np | ||
import scipy as sp | ||
import scipy.sparse | ||
from numba.core import cgutils, types | ||
from numba.core.imputils import impl_ret_borrowed | ||
from numba.extending import ( | ||
NativeValue, | ||
box, | ||
intrinsic, | ||
make_attribute_wrapper, | ||
models, | ||
overload, | ||
overload_attribute, | ||
overload_method, | ||
register_model, | ||
typeof_impl, | ||
unbox, | ||
) | ||
|
||
|
||
class CSMatrixType(types.Type):
    """Numba `Type` mirroring the base class `scipy.sparse.compressed._cs_matrix`.

    Concrete subclasses provide the class-level ``name`` and override
    ``instance_class`` to build the matching SciPy object.
    """

    # Set by subclasses; passed to `types.Type.__init__` and used in `key`.
    name: str

    @staticmethod
    def instance_class(data, indices, indptr, shape):
        """Build the Python-level sparse matrix; subclasses must override."""
        raise NotImplementedError()

    def __init__(self, dtype):
        # Member types of the native struct: a 1-d value array of `dtype`,
        # two 1-d int32 index arrays and a pair of int64 for the shape.
        self.dtype = dtype
        self.shape = types.UniTuple(types.int64, 2)
        self.indptr = types.Array(types.int32, 1, "A")
        self.indices = types.Array(types.int32, 1, "A")
        self.data = types.Array(dtype, 1, "A")
        super().__init__(self.name)

    @property
    def key(self):
        """Identify the type by its format name and element dtype."""
        return (self.name, self.dtype)
|
||
|
||
# Expose each struct member under the same attribute name in compiled code.
for _member in ("data", "indices", "indptr", "shape"):
    make_attribute_wrapper(CSMatrixType, _member, _member)
|
||
|
||
class CSRMatrixType(CSMatrixType):
    """Numba type for `scipy.sparse.csr_matrix`."""

    name = "csr_matrix"

    @staticmethod
    def instance_class(data, indices, indptr, shape):
        """Wrap the given arrays in a `csr_matrix` without copying them."""
        components = (data, indices, indptr)
        return sp.sparse.csr_matrix(components, shape, copy=False)
|
||
|
||
class CSCMatrixType(CSMatrixType):
    """Numba type for `scipy.sparse.csc_matrix`."""

    name = "csc_matrix"

    @staticmethod
    def instance_class(data, indices, indptr, shape):
        """Wrap the given arrays in a `csc_matrix` without copying them."""
        components = (data, indices, indptr)
        return sp.sparse.csc_matrix(components, shape, copy=False)
|
||
|
||
@typeof_impl.register(sp.sparse.csc_matrix)
def typeof_csc_matrix(val, c):
    """Map a `csc_matrix` to a `CSCMatrixType` with the dtype of ``val.data``."""
    data_type = typeof_impl(val.data, c)
    element_dtype = data_type.dtype
    return CSCMatrixType(element_dtype)
|
||
|
||
@typeof_impl.register(sp.sparse.csr_matrix)
def typeof_csr_matrix(val, c):
    """Map a `csr_matrix` to a `CSRMatrixType` with the dtype of ``val.data``."""
    data_type = typeof_impl(val.data, c)
    element_dtype = data_type.dtype
    return CSRMatrixType(element_dtype)
|
||
|
||
@register_model(CSRMatrixType)
class CSRMatrixModel(models.StructModel):
    """Struct data model holding the four members of a `CSRMatrixType`."""

    def __init__(self, dmm, fe_type):
        # Each member reuses the Numba type stored on the front-end type
        # under the same attribute name.
        member_names = ("data", "indices", "indptr", "shape")
        members = [(label, getattr(fe_type, label)) for label in member_names]
        super().__init__(dmm, fe_type, members)
|
||
|
||
@register_model(CSCMatrixType)
class CSCMatrixModel(models.StructModel):
    """Struct data model holding the four members of a `CSCMatrixType`."""

    def __init__(self, dmm, fe_type):
        # Each member reuses the Numba type stored on the front-end type
        # under the same attribute name.
        member_names = ("data", "indices", "indptr", "shape")
        members = [(label, getattr(fe_type, label)) for label in member_names]
        super().__init__(dmm, fe_type, members)
|
||
|
||
@unbox(CSCMatrixType)
@unbox(CSRMatrixType)
def unbox_matrix(typ, obj, c):
    """Convert a SciPy CSC/CSR matrix object into its native struct.

    Parameters
    ----------
    typ
        The `CSCMatrixType`/`CSRMatrixType` instance describing `obj`.
    obj
        The Python object whose ``data``, ``indices``, ``indptr`` and
        ``shape`` attributes are read.
    c
        The unboxing context supplying ``context``, ``builder`` and ``pyapi``.

    Returns
    -------
    NativeValue
        The populated struct.  ``is_error`` is the logical OR of the error
        flags reported while unboxing the four members, so a failed member
        conversion is no longer reported as a success.
    """
    member_names = ("data", "indices", "indptr", "shape")

    struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder)

    # Fetch every attribute first (each is a new reference) ...
    py_members = [c.pyapi.object_getattr_string(obj, name) for name in member_names]

    # ... convert each one with the unboxer of the matching member type ...
    native_members = [
        c.unbox(getattr(typ, name), py_member)
        for name, py_member in zip(member_names, py_members)
    ]
    for name, native in zip(member_names, native_members):
        setattr(struct_ptr, name, native.value)

    # ... and release the attribute references obtained above.
    for py_member in py_members:
        c.pyapi.decref(py_member)

    # Propagate member failures instead of unconditionally reporting success.
    is_error = cgutils.false_bit
    for native in native_members:
        is_error = c.builder.or_(is_error, native.is_error)

    return NativeValue(struct_ptr._getvalue(), is_error=is_error)
|
||
|
||
@box(CSCMatrixType)
@box(CSRMatrixType)
def box_matrix(typ, val, c):
    """Convert a native sparse-matrix struct back into a SciPy object.

    The four struct members are boxed individually and handed to
    ``typ.instance_class``, which is serialized into the compiled code and
    called with them as positional arguments.
    """
    member_names = ("data", "indices", "indptr", "shape")
    native = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)

    boxed_members = [
        c.box(getattr(typ, member), getattr(native, member))
        for member in member_names
    ]

    # NOTE(review): each boxed member is incref'ed here and decref'ed once
    # after the call; confirm this balance is intended and does not leave an
    # extra reference on the boxed objects.
    for boxed in boxed_members:
        c.pyapi.incref(boxed)

    serialized_ctor = c.pyapi.serialize_object(typ.instance_class)
    ctor_obj = c.pyapi.unserialize(serialized_ctor)
    matrix_obj = c.pyapi.call_function_objargs(ctor_obj, tuple(boxed_members))

    for boxed in boxed_members:
        c.pyapi.decref(boxed)

    return matrix_obj
|
||
|
||
@overload(np.shape)
def overload_sparse_shape(x):
    """Make ``np.shape`` return the ``shape`` member of sparse matrix types."""
    if not isinstance(x, CSMatrixType):
        # Not a sparse matrix: decline so other overloads can apply.
        return None

    def shape_impl(x):
        return x.shape

    return shape_impl
|
||
|
||
@overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(inst):
    """Provide the ``ndim`` attribute, which is the constant 2."""
    if isinstance(inst, CSMatrixType):

        def ndim(inst):
            return 2

        return ndim

    return None
|
||
|
||
@intrinsic
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
    """Build a new struct of type `inst` from the given member values.

    The signature returns `inst` and takes the matrix followed by values
    typed as the matrix type's own ``data``, ``indices``, ``indptr`` and
    ``shape`` members.
    """

    def codegen(context, builder, signature, args):
        result_type = signature.return_type
        new_struct = cgutils.create_struct_proxy(result_type)(context, builder)
        # The first argument (the source matrix) is not used for its value.
        (
            new_struct.data,
            new_struct.indices,
            new_struct.indptr,
            new_struct.shape,
        ) = args[1:]
        return impl_ret_borrowed(
            context, builder, result_type, new_struct._getvalue()
        )

    signature = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)
    return signature, codegen
|
||
|
||
@overload_method(CSMatrixType, "copy")
def overload_sparse_copy(inst):
    """Implement ``copy()`` by copying the three arrays and reusing ``shape``."""
    if isinstance(inst, CSMatrixType):

        def copy(inst):
            new_data = inst.data.copy()
            new_indices = inst.indices.copy()
            new_indptr = inst.indptr.copy()
            return _sparse_copy(inst, new_data, new_indices, new_indptr, inst.shape)

        return copy

    return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numba | ||
import numpy as np | ||
import scipy.sparse | ||
|
||
|
||
def test_sparse_unboxing():
    """CSR and CSC arguments unbox and expose their ``shape`` in jitted code."""
    csr_input = scipy.sparse.csr_matrix(np.eye(100))
    csc_input = scipy.sparse.csc_matrix(np.eye(101))

    @numba.njit
    def get_shapes(x, y):
        return x.shape, y.shape

    expected = (csr_input.shape, csc_input.shape)
    assert get_shapes(csr_input, csc_input) == expected
|
||
|
||
def test_sparse_boxing():
    """Matrices returned from jitted code keep their arrays and shape."""
    csr_input = scipy.sparse.csr_matrix(np.eye(100))
    csc_input = scipy.sparse.csc_matrix(np.eye(101))

    @numba.njit
    def passthrough(x, y):
        return x, y

    csr_output, csc_output = passthrough(csr_input, csc_input)

    # Compare the CSR pair first, then the CSC pair, member by member.
    for returned, original in ((csr_output, csr_input), (csc_output, csc_input)):
        for array_name in ("data", "indices", "indptr"):
            assert np.array_equal(
                getattr(returned, array_name), getattr(original, array_name)
            )
        assert returned.shape == original.shape
|
||
|
||
def test_sparse_shape():
    """``np.shape`` works on a sparse matrix inside jitted code."""
    matrix = scipy.sparse.csr_matrix(np.eye(100))

    @numba.njit
    def shape_of(x):
        return np.shape(x)

    assert shape_of(matrix) == (100, 100)
|
||
|
||
def test_sparse_ndim():
    """The ``ndim`` attribute of a sparse matrix is 2 inside jitted code."""
    matrix = scipy.sparse.csr_matrix(np.eye(100))

    @numba.njit
    def ndim_of(x):
        return x.ndim

    assert ndim_of(matrix) == 2
|
||
|
||
def test_sparse_copy():
    """``copy()`` in jitted code yields a matrix with equal data and indices."""
    matrix = scipy.sparse.csr_matrix(np.eye(100))

    @numba.njit
    def copy_matches(x):
        y = x.copy()
        return (
            y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices)
        )

    assert copy_matches(matrix)