Skip to content

Commit

Permalink
Merge pull request #1786 from IntelPython/fix-gh-1769
Browse files Browse the repository at this point in the history
Fix gh-1769 - setting shape by a scalar
  • Loading branch information
oleksandr-pavlyk authored Aug 6, 2024
2 parents 31486b8 + 9173fed commit b9242f4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
22 changes: 21 additions & 1 deletion dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,22 @@ cdef class usm_ndarray:

@shape.setter
def shape(self, new_shape):
"""
Modifies usm_ndarray instance in-place by changing its metadata
about the shape and the strides of the array, or raises
`AttributeError` exception if in-place change is not possible.
Args:
new_shape: (tuple, int)
New shape. Only non-negative values are supported.
The new shape may not lead to the change in the
number of elements in the array.
Whether the array can be reshape in-place depends on its
strides. Use :func:`dpctl.tensor.reshape` function which
always succeeds to reshape the array by performing a copy
if necessary.
"""
cdef int new_nd = -1
cdef Py_ssize_t nelems = -1
cdef int err = 0
Expand All @@ -576,7 +592,11 @@ cdef class usm_ndarray:

from ._reshape import reshaped_strides

new_nd = len(new_shape)
try:
new_nd = len(new_shape)
except TypeError:
new_nd = 1
new_shape = (new_shape,)
try:
new_shape = tuple(operator.index(dim) for dim in new_shape)
except TypeError:
Expand Down
43 changes: 21 additions & 22 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ctypes
import numbers
from math import prod

import numpy as np
import pytest
Expand Down Expand Up @@ -1102,7 +1103,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
skip_if_dtype_not_supported(dtype, q)
shape = (2, 4, 3)
Xnp = (
np.random.randint(-10, 10, size=np.prod(shape))
np.random.randint(-10, 10, size=prod(shape))
.astype(dtype)
.reshape(shape)
)
Expand Down Expand Up @@ -1307,6 +1308,10 @@ def relaxed_strides_equal(st1, st2, sh):
X = dpt.usm_ndarray(sh_s, dtype="?")
X.shape = sh_f
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
sz = X.size
X.shape = sz
assert X.shape == (sz,)
assert relaxed_strides_equal(X.strides, (1,), (sz,))

X = dpt.usm_ndarray(sh_s, dtype="u4")
with pytest.raises(TypeError):
Expand Down Expand Up @@ -2077,11 +2082,9 @@ def test_tril(dtype):
skip_if_dtype_not_supported(dtype, q)

shape = (2, 3, 4, 5, 5)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
)
X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape)
Y = dpt.tril(X)
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape)
Ynp = np.tril(Xnp)
assert Y.dtype == Ynp.dtype
assert np.array_equal(Ynp, dpt.asnumpy(Y))
Expand All @@ -2093,11 +2096,9 @@ def test_triu(dtype):
skip_if_dtype_not_supported(dtype, q)

shape = (4, 5)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
)
X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape)
Y = dpt.triu(X, k=1)
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape)
Ynp = np.triu(Xnp, k=1)
assert Y.dtype == Ynp.dtype
assert np.array_equal(Ynp, dpt.asnumpy(Y))
Expand All @@ -2110,7 +2111,7 @@ def test_tri_usm_type(tri_fn, usm_type):
dtype = dpt.uint16

shape = (2, 3, 4, 5, 5)
size = np.prod(shape)
size = prod(shape)
X = dpt.reshape(
dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape
)
Expand All @@ -2129,11 +2130,11 @@ def test_tril_slice():
q = get_queue_or_skip()

shape = (6, 10)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
)[1:, ::-2]
X = dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape)[
1:, ::-2
]
Y = dpt.tril(X)
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape)[1:, ::-2]
Xnp = np.arange(prod(shape), dtype="int").reshape(shape)[1:, ::-2]
Ynp = np.tril(Xnp)
assert Y.dtype == Ynp.dtype
assert np.array_equal(Ynp, dpt.asnumpy(Y))
Expand All @@ -2144,14 +2145,12 @@ def test_triu_permute_dims():

shape = (2, 3, 4, 5)
X = dpt.permute_dims(
dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
),
dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape),
(3, 2, 1, 0),
)
Y = dpt.triu(X)
Xnp = np.transpose(
np.arange(np.prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
np.arange(prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
)
Ynp = np.triu(Xnp)
assert Y.dtype == Ynp.dtype
Expand Down Expand Up @@ -2189,12 +2188,12 @@ def test_triu_order_k(order, k):

shape = (3, 3)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
dpt.arange(prod(shape), dtype="int", sycl_queue=q),
shape,
order=order,
)
Y = dpt.triu(X, k=k)
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order)
Ynp = np.triu(Xnp, k=k)
assert Y.dtype == Ynp.dtype
assert X.flags == Y.flags
Expand All @@ -2210,12 +2209,12 @@ def test_tril_order_k(order, k):
pytest.skip("Queue could not be created")
shape = (3, 3)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
dpt.arange(prod(shape), dtype="int", sycl_queue=q),
shape,
order=order,
)
Y = dpt.tril(X, k=k)
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order)
Ynp = np.tril(Xnp, k=k)
assert Y.dtype == Ynp.dtype
assert X.flags == Y.flags
Expand Down

0 comments on commit b9242f4

Please sign in to comment.