-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.37】为 Paddle 新增 householder_product API -part #58214
Changes from 7 commits
5f627a5
1dcde7a
ce7f286
708fd5d
75c9e06
3f30938
1c49aa4
8899ce7
f477ba5
3bedcfa
d43aaba
173a4cd
ac67f4a
1645665
8bfcd24
b5559d3
b92a244
67c350d
8f64185
77aafeb
d8c3ef1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3724,3 +3724,120 @@ def cdist( | |
return paddle.linalg.norm( | ||
x[..., None, :] - y[..., None, :, :], p=p, axis=-1 | ||
) | ||
|
||
|
||
def householder_product(A, tau, name=None): | ||
r""" | ||
|
||
Computes the first n columns of a product of Householder matrices. | ||
|
||
This function can get the vector :math:`\omega_{i}` from matrix `A`(m x n), the :math:`i-1` elements are zeros, and the i-th is `1`, the rest of the elements are from i-th column of `A`. | ||
And with the vector `tau` can calculate the first n columns of a product of Householder matrices. | ||
|
||
:math:`H_i = I_m - \tau_i \omega_i \omega_i^H` | ||
|
||
Args: | ||
A (Tensor): A tensor with shape (*, m, n) where * is zero or more batch dimensions. | ||
tau (Tensor): A tensor with shape (*, k) where * is zero or more batch dimensions. | ||
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. | ||
|
||
Returns: | ||
Tensor, the dtype is same as input tensor, the Q in QR decomposition. | ||
|
||
:math:`out = Q = H_1H_2H_3...H_k` | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
>>> A = paddle.to_tensor([[-1.1280, 0.9012, -0.0190], | ||
... [ 0.3699, 2.2133, -1.4792], | ||
... [ 0.0308, 0.3361, -3.1761], | ||
... [-0.0726, 0.8245, -0.3812]]) | ||
>>> tau = paddle.to_tensor([1.7497, 1.1156, 1.7462]) | ||
>>> Q = paddle.linalg.householder_product(A, tau) | ||
>>> Q | ||
cocoshe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Tensor(shape=[4, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||
[[-0.74969995, -0.02181768, 0.31115776], | ||
[-0.64721400, -0.12367040, -0.21738708], | ||
[-0.05389076, -0.37562513, -0.84836429], | ||
[ 0.12702821, -0.91822827, 0.36892807]]) | ||
""" | ||
|
||
check_dtype( | ||
A.dtype, | ||
'x', | ||
[ | ||
'float32', | ||
'float64', | ||
], | ||
'householder_product', | ||
) | ||
check_dtype( | ||
tau.dtype, | ||
'tau', | ||
[ | ||
'float32', | ||
'float64', | ||
], | ||
'householder_product', | ||
) | ||
assert ( | ||
A.dtype == tau.dtype | ||
), "The input A must have the same dtype with input tau.\n" | ||
assert ( | ||
len(A.shape) >= 2 | ||
and len(tau.shape) >= 1 | ||
and len(A.shape) == len(tau.shape) + 1 | ||
), ( | ||
"The input A must have more than 2 dimensions, and input tau must have more than 1 dimension," | ||
"and the dimension of A is 1 larger than the dimension of tau\n" | ||
) | ||
assert ( | ||
A.shape[-2] >= A.shape[-1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是保证A矩阵的m >= n吗,tau矩阵是否要保证 n >= k? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. householder_product当前实现方案支持了实值,torch中有对复数的支持吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 辛苦review~
是的,这里应该在加一些assert和对应的单测,我稍后补上,限制条件发现了一些小bug
是的,我当时直接用了paddle.norm(它暂时没支持复数实现),不过应该可以直接手动实现一下norm,我稍后试下再补上相应单测~ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的 |
||
), "The rows of input A must be greater than or equal to the columns of input A.\n" | ||
for idx, _ in enumerate(A.shape[:-2]): | ||
assert ( | ||
A.shape[idx] == tau.shape[idx] | ||
), "The input A must have the same batch dimensions with input tau.\n" | ||
|
||
def _householder_product(A, tau): | ||
m, n = A.shape[-2:] | ||
Q = paddle.eye(m) | ||
for i in range(min(m, n)): | ||
normx = paddle.norm(A[i:, i]) | ||
sign = 1 if A[i, i] < 0 else -1 | ||
w = A[i:, i] | ||
if in_dynamic_mode(): | ||
w[0] = 1 | ||
else: | ||
w = paddle.static.setitem(w, 0, 1) | ||
w = w.reshape([-1, 1]) | ||
if in_dynamic_mode(): | ||
Q[:, i:] = Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]) | ||
else: | ||
Q = paddle.static.setitem( | ||
Q, | ||
(slice(None), slice(i, None)), | ||
Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]), | ||
) | ||
return Q[:, :n] | ||
|
||
if len(A.shape) == 2: | ||
return _householder_product(A, tau) | ||
m, n = A.shape[-2:] | ||
org_A_shape = A.shape | ||
org_tau_shape = tau.shape | ||
A = A.reshape((-1, org_A_shape[-2], org_A_shape[-1])) | ||
tau = tau.reshape((-1, org_tau_shape[-1])) | ||
n_batch = A.shape[0] | ||
out = paddle.zeros([n_batch, m, n]) | ||
for i in range(n_batch): | ||
if in_dynamic_mode(): | ||
out[i] = _householder_product(A[i], tau[i]) | ||
else: | ||
out = paddle.static.setitem( | ||
out, i, _householder_product(A[i], tau[i]) | ||
) | ||
out = out.reshape(org_A_shape) | ||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
according to API naming conventions, enter the name of Tensor using
x
, and the rfc should also be modified synchronously