Skip to content

Commit

Permalink
Update the field names for the qr() and svd() namedtuples
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Nov 11, 2021
1 parent 9c9ffe1 commit df5f6f5
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,29 +399,29 @@ def test_qr(x, kw):
M, N = x.shape[-2:]
K = min(M, N)

_test_namedtuple(res, ['q', 'r'], 'qr')
q = res.q
r = res.r
_test_namedtuple(res, ['Q', 'R'], 'qr')
Q = res.Q
R = res.R

assert q.dtype == x.dtype, "qr().q did not return the correct dtype"
assert Q.dtype == x.dtype, "qr().Q did not return the correct dtype"
if mode == 'complete':
assert q.shape == x.shape[:-2] + (M, M), "qr().q did not return the correct shape"
assert Q.shape == x.shape[:-2] + (M, M), "qr().Q did not return the correct shape"
else:
assert q.shape == x.shape[:-2] + (M, K), "qr().q did not return the correct shape"
assert Q.shape == x.shape[:-2] + (M, K), "qr().Q did not return the correct shape"

assert r.dtype == x.dtype, "qr().r did not return the correct dtype"
assert R.dtype == x.dtype, "qr().R did not return the correct dtype"
if mode == 'complete':
assert r.shape == x.shape[:-2] + (M, N), "qr().r did not return the correct shape"
assert R.shape == x.shape[:-2] + (M, N), "qr().R did not return the correct shape"
else:
assert r.shape == x.shape[:-2] + (K, N), "qr().r did not return the correct shape"
assert R.shape == x.shape[:-2] + (K, N), "qr().R did not return the correct shape"

_test_stacks(lambda x: linalg.qr(x, **kw).q, x, res=q)
_test_stacks(lambda x: linalg.qr(x, **kw).r, x, res=r)
_test_stacks(lambda x: linalg.qr(x, **kw).Q, x, res=Q)
_test_stacks(lambda x: linalg.qr(x, **kw).R, x, res=R)

# TODO: Test that q is orthonormal
# TODO: Test that Q is orthonormal

# Check that r is upper-triangular.
assert_exactly_equal(r, _array_module.triu(r))
# Check that R is upper-triangular.
assert_exactly_equal(R, _array_module.triu(R))

@pytest.mark.xp_extension('linalg')
@given(
Expand Down Expand Up @@ -506,29 +506,29 @@ def test_svd(x, kw):
*stack, M, N = x.shape
K = min(M, N)

_test_namedtuple(res, ['u', 's', 'vh'], 'svd')
_test_namedtuple(res, ['U', 'S', 'Vh'], 'svd')

u, s, vh = res
U, S, Vh = res

assert u.dtype == x.dtype, "svd().u did not return the correct dtype"
assert s.dtype == x.dtype, "svd().s did not return the correct dtype"
assert vh.dtype == x.dtype, "svd().vh did not return the correct dtype"
assert U.dtype == x.dtype, "svd().U did not return the correct dtype"
assert S.dtype == x.dtype, "svd().S did not return the correct dtype"
assert Vh.dtype == x.dtype, "svd().Vh did not return the correct dtype"

if full_matrices:
assert u.shape == (*stack, M, M), "svd().u did not return the correct shape"
assert vh.shape == (*stack, N, N), "svd().vh did not return the correct shape"
assert U.shape == (*stack, M, M), "svd().U did not return the correct shape"
assert Vh.shape == (*stack, N, N), "svd().Vh did not return the correct shape"
else:
assert u.shape == (*stack, M, K), "svd(full_matrices=False).u did not return the correct shape"
assert vh.shape == (*stack, K, N), "svd(full_matrices=False).vh did not return the correct shape"
assert s.shape == (*stack, K), "svd().s did not return the correct shape"
assert U.shape == (*stack, M, K), "svd(full_matrices=False).U did not return the correct shape"
assert Vh.shape == (*stack, K, N), "svd(full_matrices=False).Vh did not return the correct shape"
assert S.shape == (*stack, K), "svd().S did not return the correct shape"

# The values of s must be sorted from largest to smallest
if K >= 1:
assert _array_module.all(s[..., :-1] >= s[..., 1:]), "svd().s values are not sorted from largest to smallest"
assert _array_module.all(S[..., :-1] >= S[..., 1:]), "svd().S values are not sorted from largest to smallest"

_test_stacks(lambda x: linalg.svd(x, **kw).u, x, res=u)
_test_stacks(lambda x: linalg.svd(x, **kw).s, x, dims=1, res=s)
_test_stacks(lambda x: linalg.svd(x, **kw).vh, x, res=vh)
_test_stacks(lambda x: linalg.svd(x, **kw).U, x, res=U)
_test_stacks(lambda x: linalg.svd(x, **kw).S, x, dims=1, res=S)
_test_stacks(lambda x: linalg.svd(x, **kw).Vh, x, res=Vh)

@pytest.mark.xp_extension('linalg')
@given(
Expand Down

0 comments on commit df5f6f5

Please sign in to comment.