Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] add op linalg solve #16913

Merged
merged 1 commit into from
Dec 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _op as _mx_nd_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -352,3 +352,57 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)


def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.

Computes the "exact" solution, `x`, of the well-determined, i.e., full
rank, linear matrix equation `ax = b`.

Parameters
----------
a : (..., M, M) ndarray
reminisce marked this conversation as resolved.
Show resolved Hide resolved
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.

Returns
-------
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to `b`.

Raises
------
MXNetError
If `a` is singular or not square.

Notes
-----
Broadcasting rules apply, see the `numpy.linalg` documentation for
details.

The solutions are computed using LAPACK routine ``_gesv``.

`a` must be square and of full-rank, i.e., all rows (or, equivalently,
columns) must be linearly independent; if either is not true, use
`lstsq` for the least-squares best "solution" of the
system/equation.

Examples
--------
Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``:

>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])

Check that the solution is correct:

>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
56 changes: 55 additions & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -370,3 +370,57 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _mx_nd_np.linalg.slogdet(a)

haojin2 marked this conversation as resolved.
Show resolved Hide resolved

def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.

Computes the "exact" solution, `x`, of the well-determined, i.e., full
rank, linear matrix equation `ax = b`.

Parameters
----------
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.

Returns
-------
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to `b`.

Raises
------
MXNetError
If `a` is singular or not square.

Notes
-----
Broadcasting rules apply, see the `numpy.linalg` documentation for
details.

The solutions are computed using LAPACK routine ``_gesv``.

`a` must be square and of full-rank, i.e., all rows (or, equivalently,
columns) must be linearly independent; if either is not true, use
`lstsq` for the least-squares best "solution" of the
system/equation.

Examples
--------
Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``:

>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])

Check that the solution is correct:

>>> np.allclose(np.dot(a, x), b)
True
"""
return _mx_nd_np.linalg.solve(a, b)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'linalg.norm',
'linalg.cholesky',
'linalg.inv',
'linalg.solve',
'shape',
'trace',
'tril',
Expand Down
55 changes: 54 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import _op as _mx_sym_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -339,3 +339,56 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)

def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.

Computes the "exact" solution, `x`, of the well-determined, i.e., full
rank, linear matrix equation `ax = b`.

Parameters
----------
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.

Returns
-------
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to `b`.

Raises
------
MXNetError
If `a` is singular or not square.

Notes
-----
Broadcasting rules apply, see the `numpy.linalg` documentation for
details.

The solutions are computed using LAPACK routine ``_gesv``.

`a` must be square and of full-rank, i.e., all rows (or, equivalently,
columns) must be linearly independent; if either is not true, use
`lstsq` for the least-squares best "solution" of the
system/equation.

Examples
--------
Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``:

>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])

Check that the solution is correct:

>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
10 changes: 10 additions & 0 deletions src/operator/c_lapack_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@
return 1; \
}

#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
int lda, int *ipiv, dtype *b, int ldb) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
Expand Down Expand Up @@ -101,4 +108,7 @@
MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

#endif // MSHADOW_USE_MKL == 0
39 changes: 37 additions & 2 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ extern "C" {

MXNET_LAPACK_FSIG_GETRI(sgetri, float)
MXNET_LAPACK_FSIG_GETRI(dgetri, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GESV(func, dtype) \
int func##_(int *n, int *nrhs, dtype *a, int *lda, \
int *ipiv, dtype *b, int *ldb, int *info);
#else
#define MXNET_LAPACK_FSIG_GESV(func, dtype) \
void func##_(int *n, int *nrhs, dtype *a, int *lda, \
int *ipiv, dtype *b, int *ldb, int *info);
#endif

MXNET_LAPACK_FSIG_GESV(sgesv, float)
MXNET_LAPACK_FSIG_GESV(dgesv, double)
}

#endif // MSHADOW_USE_MKL == 0
Expand Down Expand Up @@ -197,6 +210,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_dpotri LAPACKE_dpotri
#define mxnet_lapack_sposv LAPACKE_sposv
#define mxnet_lapack_dposv LAPACKE_dposv
#define MXNET_LAPACK_dgesv LAPACKE_dgesv
#define MXNET_LAPACK_sgesv LAPACKE_sgesv

// The following functions differ in signature from the
// MXNET_LAPACK-signature and have to be wrapped.
Expand Down Expand Up @@ -440,9 +455,23 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#else

#define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \
int n, int nrhs, dtype *a, int lda, \
int *ipiv, dtype *b, int ldb) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GESV(s, float)
MXNET_LAPACK_CWRAP_GESV(d, double)

#else

#define MXNET_LAPACK_ROW_MAJOR 101
#define MXNET_LAPACK_COL_MAJOR 102
Expand Down Expand Up @@ -473,6 +502,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
int ldut, dtype* s, dtype* v, int ldv, \
dtype* work, int lwork);

#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
int lda, int *ipiv, dtype *b, int ldb); \

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...);
Expand Down Expand Up @@ -501,6 +533,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
Expand Down
Loading