Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.37】为 Paddle 新增 householder_product API -part #58214

Merged
merged 21 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions python/paddle/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .tensor import inverse as inv
from .tensor.linalg import (
cholesky,
cholesky_solve,
cond,
corrcoef,
cov,
det,
eig,
eigh,
eigvals,
eigvalsh,
lstsq,
lu,
lu_unpack,
matrix_power,
matrix_rank,
multi_dot,
norm,
pca_lowrank,
pinv,
qr,
slogdet,
solve,
svd,
triangular_solve,
)
from .tensor import inverse as inv # noqa: F401
from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import cholesky_solve # noqa: F401
from .tensor.linalg import cond # noqa: F401
from .tensor.linalg import corrcoef # noqa: F401
from .tensor.linalg import cov # noqa: F401
from .tensor.linalg import det # noqa: F401
from .tensor.linalg import eig # noqa: F401
from .tensor.linalg import eigh # noqa: F401
from .tensor.linalg import eigvals # noqa: F401
from .tensor.linalg import eigvalsh # noqa: F401
from .tensor.linalg import householder_product # noqa: F401
from .tensor.linalg import lu # noqa: F401
from .tensor.linalg import lu_unpack # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import matrix_rank # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import pca_lowrank # noqa: F401
from .tensor.linalg import pinv # noqa: F401
from .tensor.linalg import qr # noqa: F401
from .tensor.linalg import slogdet # noqa: F401
from .tensor.linalg import solve # noqa: F401
from .tensor.linalg import svd # noqa: F401
from .tensor.linalg import triangular_solve # noqa: F401
from .tensor.linalg import lstsq

__all__ = [
'cholesky',
Expand All @@ -53,6 +52,7 @@
'matrix_rank',
'svd',
'qr',
'householder_product',
'pca_lowrank',
'lu',
'lu_unpack',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from .linalg import eig # noqa: F401
from .linalg import matrix_power # noqa: F401
from .linalg import qr # noqa: F401
from .linalg import householder_product # noqa: F401
from .linalg import eigvals # noqa: F401
from .linalg import multi_dot # noqa: F401
from .linalg import svd # noqa: F401
Expand Down Expand Up @@ -413,6 +414,7 @@
'mv',
'matrix_power',
'qr',
'householder_product',
'pca_lowrank',
'eigvals',
'eigvalsh',
Expand Down
117 changes: 117 additions & 0 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3724,3 +3724,120 @@ def cdist(
return paddle.linalg.norm(
x[..., None, :] - y[..., None, :, :], p=p, axis=-1
)


def householder_product(A, tau, name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the API naming conventions, the input Tensor should be named `x`; the RFC should be updated accordingly.

r"""

Computes the first n columns of a product of Householder matrices.

This function reads the Householder vector :math:`\omega_{i}` from the i-th column of matrix `A` (m x n): its first :math:`i-1` elements are zeros, its i-th element is `1`, and the remaining elements are taken from the i-th column of `A` below the diagonal.
Together with the scalar factors in `tau`, these vectors determine the first n columns of the product of Householder matrices.

:math:`H_i = I_m - \tau_i \omega_i \omega_i^H`

Args:
A (Tensor): A tensor with shape (*, m, n) where * is zero or more batch dimensions.
tau (Tensor): A tensor with shape (*, k) where * is zero or more batch dimensions.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
Tensor, the dtype is same as input tensor, the Q in QR decomposition.

:math:`out = Q = H_1H_2H_3...H_k`

Examples:
.. code-block:: python

>>> import paddle
>>> A = paddle.to_tensor([[-1.1280, 0.9012, -0.0190],
... [ 0.3699, 2.2133, -1.4792],
... [ 0.0308, 0.3361, -3.1761],
... [-0.0726, 0.8245, -0.3812]])
>>> tau = paddle.to_tensor([1.7497, 1.1156, 1.7462])
>>> Q = paddle.linalg.householder_product(A, tau)
>>> Q
cocoshe marked this conversation as resolved.
Show resolved Hide resolved
Tensor(shape=[4, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[-0.74969995, -0.02181768, 0.31115776],
[-0.64721400, -0.12367040, -0.21738708],
[-0.05389076, -0.37562513, -0.84836429],
[ 0.12702821, -0.91822827, 0.36892807]])
"""

check_dtype(
A.dtype,
'x',
[
'float32',
'float64',
],
'householder_product',
)
check_dtype(
tau.dtype,
'tau',
[
'float32',
'float64',
],
'householder_product',
)
assert (
A.dtype == tau.dtype
), "The input A must have the same dtype with input tau.\n"
assert (
len(A.shape) >= 2
and len(tau.shape) >= 1
and len(A.shape) == len(tau.shape) + 1
), (
"The input A must have more than 2 dimensions, and input tau must have more than 1 dimension,"
"and the dimension of A is 1 larger than the dimension of tau\n"
)
assert (
A.shape[-2] >= A.shape[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是保证A矩阵的m >= n吗,tau矩阵是否要保证 n >= k?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

householder_product当前实现方案支持了实值,torch中有对复数的支持吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

辛苦review~

这里是保证A矩阵的m >= n吗,tau矩阵是否要保证 n >= k?

是的,这里应该在加一些assert和对应的单测,我稍后补上,限制条件发现了一些小bug

householder_product当前实现方案支持了实值,torch中有对复数的支持吗?

是的,我当时直接用了paddle.norm(它暂时没支持复数实现),不过应该可以直接手动实现一下norm,我稍后试下再补上相应单测~

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

), "The rows of input A must be greater than or equal to the columns of input A.\n"
for idx, _ in enumerate(A.shape[:-2]):
assert (
A.shape[idx] == tau.shape[idx]
), "The input A must have the same batch dimensions with input tau.\n"

def _householder_product(A, tau):
    """Accumulate the Householder reflectors of one (unbatched) matrix.

    For each reflector index ``i`` the vector ``w`` is taken from
    ``A[i:, i]`` with its leading entry replaced by 1, and the trailing
    columns of the accumulator ``Q`` are updated as
    ``Q[:, i:] <- Q[:, i:] - (Q[:, i:] @ w @ w.T) * tau[i]``.

    Args:
        A: 2-D tensor; its last two dimensions are read as ``(m, n)``.
        tau: 1-D tensor of scalar factors, one per reflector.

    Returns:
        The first ``n`` columns of the accumulated ``m x m`` product.
    """
    m, n = A.shape[-2:]
    # Only as many reflectors as tau provides can be applied; iterating up
    # to min(m, n) would index past the end of tau when it is shorter.
    k = tau.shape[-1]
    # Start from an identity that matches the input dtype so float64 inputs
    # are not mixed with a default-dtype accumulator.
    Q = paddle.eye(m).astype(A.dtype)
    # The execution mode is loop-invariant, so query it once.
    dynamic = in_dynamic_mode()
    for i in range(min(k, n)):
        # NOTE(review): in dynamic mode the in-place write below may alias
        # the caller's A if slicing returns a view -- confirm.
        w = A[i:, i]
        if dynamic:
            w[0] = 1
        else:
            # Static graphs cannot assign in place; setitem returns a new var.
            w = paddle.static.setitem(w, 0, 1)
        w = w.reshape([-1, 1])
        reflected = Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i])
        if dynamic:
            Q[:, i:] = reflected
        else:
            Q = paddle.static.setitem(
                Q,
                (slice(None), slice(i, None)),
                reflected,
            )
    return Q[:, :n]

if len(A.shape) == 2:
return _householder_product(A, tau)
m, n = A.shape[-2:]
org_A_shape = A.shape
org_tau_shape = tau.shape
A = A.reshape((-1, org_A_shape[-2], org_A_shape[-1]))
tau = tau.reshape((-1, org_tau_shape[-1]))
n_batch = A.shape[0]
out = paddle.zeros([n_batch, m, n])
for i in range(n_batch):
if in_dynamic_mode():
out[i] = _householder_product(A[i], tau[i])
else:
out = paddle.static.setitem(
out, i, _householder_product(A[i], tau[i])
)
out = out.reshape(org_A_shape)
return out
Loading