Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes out kwarg in matmul when axes are appended to inputs #1610

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions dpctl/tensor/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,11 +599,16 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
matrices on which to perform matrix multiplication.
out (Optional[usm_ndarray]):
the array into which the result of the matrix product is written.
If `None` then a new array is returned.
The data type of `out` must match the expected data type of the
result or (if provided) `dtype`.
If `None` then a new array is returned. Default: `None`.
dtype (Optional[dtype]):
data type of the returned array. If `None`, the data type of the
returned array is determined by the Type Promotion Rules.
Default: `None`.
order (["K", "C", "F", "A"]):
memory layout of the output array, if `out` is `None`, otherwise
the `order` parameter value is not used.

the `order` parameter value is not used. Default: `K`.
Returns:
usm_ndarray:
* if both `x1` and `x2` are one-dimensional arrays with shape
Expand All @@ -613,8 +618,8 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
a two-dimensional array with shape `(K, N)`, returned array is a
two-dimensional array with shape `(M, N)` and contains the
conventional matrix product.
* if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an
array with shape `(..., K, N)`, returned array contains the
* if `x1` is a one-dimensional array with shape `(K,)` and `x2` is
an array with shape `(..., K, N)`, returned array contains the
conventional matrix product and has shape `(..., N)`.
* if `x1` is an array with shape `(..., M, K)` and `x2` is a
one-dimensional array with shape `(K,)`, returned array has shape
Expand Down Expand Up @@ -741,12 +746,21 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
final_res_shape = tuple(
res_shape[i]
for i in range(-len(res_shape), 0)
if i not in appended_axes
)
if out.shape != final_res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
f"Expected output shape is {res_shape}, got {out.shape}"
f"Expected output shape is {final_res_shape}, got {out.shape}"
)

if appended_axes:
out = dpt.expand_dims(out, appended_axes)
orig_out = out

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed," f"got {out.dtype}"
Expand Down
25 changes: 25 additions & 0 deletions dpctl/tests/test_usm_ndarray_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,3 +980,28 @@ def test_vecdot_contig_small():
res = dpt.vecdot(v1, v2)
assert dpt.all(res[:-1] == 0)
assert res[-1] == n


def test_matmul_out_appended_axes():
get_queue_or_skip()

n0, n1, n2 = 4, 10, 5
# vm
x1 = dpt.ones(n1, dtype="i4")
x2 = dpt.ones((n0, n1, n2), dtype="i4")
out = dpt.empty((n0, n2), dtype="i4")

dpt.matmul(x1, x2, out=out)
assert dpt.all(out == n1)

# mv
x2 = x2.mT
x1, x2 = x2, x1
dpt.matmul(x1, x2, out=out)
assert dpt.all(out == n1)

# vv
x1 = dpt.ones(n1, dtype="i4")
out = dpt.empty((), dtype="i4")
dpt.matmul(x1, x2, out=out)
assert out == n1