-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
usm_ndarray.to_device(dev, stream=queue) support #1331
Conversation
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1331/index.html |
Array API standard conformance tests for dpctl=0.14.6dev0=py310h7bf5fec_121 ran successfully. |
Array API standard conformance tests for dpctl=0.14.6dev0=py310h7bf5fec_122 ran successfully. |
This commit short-circuited broadcastability of shapes for zero-size rhs in __setitem__. Refix array API test failure without introducing this regression and reuse _manipulation_functions._broadcast_strides as suggested by @ndgrigorian
b9f5fc7
to
8b03642
Compare
Array API standard conformance tests for dpctl=0.14.6dev1=py310h7bf5fec_35 ran successfully. |
The following example is still not working: In [1]: import dpnp, numpy, dpctl, dpctl.tensor as dpt
In [2]: a = dpt.ones((1, 2, 1, 4))
In [3]: b = dpt.empty((2, 3, 4))
In [4]: b[...] = a
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 b[...] = a
File dpctl/tensor/_usmarray.pyx:1178, in dpctl.tensor._usmarray.usm_ndarray.__setitem__()
File ~/miniconda3/envs/dpnp_dev/lib/python3.9/site-packages/dpctl/tensor/_copy_utils.py:307, in _copy_from_usm_ndarray_to_usm_ndarray(dst, src)
305 else:
306 src_same_shape = src
--> 307 src_same_shape.shape = common_shape
309 _copy_same_shape(dst, src_same_shape)
File dpctl/tensor/_usmarray.pyx:576, in dpctl.tensor._usmarray.usm_ndarray.shape.__set__()
TypeError: Can not reshape array of size 8 into (2, 3, 4)
In [5]: dpctl.__version__
Out[5]: '0.14.6dev1+35.g5379d9342'
In [6]: conda list dpctl
# packages in environment at /home/xantvol/miniconda3/envs/dpnp_dev:
#
# Name Version Build Channel
dpctl 0.14.6dev1 py39h7bf5fec_35 file:///mnt/c/Users/antonvol/Downloads/dpctl Linux Python 3.9 |
Also added comment to explain the logic of the code.
Array API standard conformance tests for dpctl=0.14.6dev1=py310h7bf5fec_39 ran successfully. |
dpnp tests works fine with the PR, thank you @oleksandr-pavlyk ! |
``` In [1]: import dpctl.tensor as dpt In [2]: a = dpt.ones((2, 3)) ...: dpt.reshape(a, (1, 6, 1)).flags Out[2]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [3]: a = dpt.ones((2, 3), order='F') ...: dpt.reshape(a, (1, 6, 1), order='F').flags Out[3]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [4]: a = dpt.ones((2, 3, 4)) ...: dpt.sum(a, axis=(1, 2), keepdims=True).flags Out[4]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True ```
A non-empty array which is effectively 1D (only one dimension has size greater than one) should be marked as both C- and F- contiguous. ``` In [1]: import dpctl.tensor as dpt In [2]: a = dpt.ones((2, 3)) ...: dpt.reshape(a, (1, 6, 1)).flags Out[2]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [3]: a = dpt.ones((2, 3), order='F') ...: dpt.reshape(a, (1, 6, 1), order='F').flags Out[3]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [4]: a = dpt.ones((2, 3, 4)) ...: dpt.sum(a, axis=(1, 2), keepdims=True).flags Out[4]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True ```
Array API standard conformance tests for dpctl=0.14.6dev1=py310h7bf5fec_44 ran successfully. |
ed4e003
to
706d80f
Compare
I've checked the latest changes, they look good. The problem reported in #1334 is gone. No new issue is observed. |
Array API standard conformance tests for dpctl=0.14.6dev1=py310h7bf5fec_43 ran successfully. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've looked through the PR and don't see any glaring issues.
LGTM, thanks @oleksandr-pavlyk !
Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞 |
Array API standard conformance tests for dpctl=0.14.6dev1=py310h7bf5fec_53 ran successfully. |
This PR fixes two bugs: closes gh-1330, and fixes array_api_tests failure regarding
usm_ndarray.to_device
not supportingstream
keyword.