Skip to content

Commit

Permalink
Corrected order='K' support in astype
Browse files Browse the repository at this point in the history
Array API tests pointed out an error in the implementation of order='K'
in dpctl.tensor.astype. Moved _empty_like_orderK and its companion
_empty_like_pair_orderK from _type_utils to _copy_utils and used
_empty_like_orderK to implement astype.

Modified import statement in _elementwise_common where _empty_like*
are used.
  • Loading branch information
oleksandr-pavlyk committed Aug 6, 2023
1 parent 07faf2b commit fd54583
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 113 deletions.
114 changes: 97 additions & 17 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import operator

import numpy as np
Expand Down Expand Up @@ -361,6 +362,96 @@ def copy(usm_ary, order="K"):
return R


def _empty_like_orderK(X, dt, usm_type=None, dev=None):
    """Allocate an uninitialized array like `X`, following order='K'.

    The result has the shape of `X` and data type `dt`. For an array `X`
    that was obtained by permutation of a contiguous array, the returned
    array will have the same shape and the same strides as `X`.

    Args:
        X: input array; must be a `dpt.usm_ndarray`.
        dt: data type of the returned array.
        usm_type: USM allocation type; `X.usm_type` is used when `None`.
        dev: device for the allocation; `X.device` is used when `None`.

    Returns:
        The array produced by `dpt.empty_like`, or a permuted (and
        possibly axis-reversed) view of one produced by `dpt.empty`.

    Raises:
        TypeError: if `X` is not a `dpt.usm_ndarray`.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray, got {type(X)}")
    usm_type = X.usm_type if usm_type is None else usm_type
    dev = X.device if dev is None else dev

    flags = X.flags
    # Inputs flagged C or F, and inputs with at most one element, need no
    # stride analysis.
    if flags["C"] or X.size <= 1:
        simple_order = "C"
    elif flags["F"]:
        simple_order = "F"
    else:
        simple_order = None
    if simple_order is not None:
        return dpt.empty_like(
            X, dtype=dt, usm_type=usm_type, device=dev, order=simple_order
        )

    ndim = X.ndim
    strides = list(X.strides)
    shape = X.shape

    def _stride_magnitude(axis):
        return builtins.abs(strides[axis])

    # Axes ordered from the largest absolute stride to the smallest.
    axes_by_stride = sorted(range(ndim), key=_stride_magnitude, reverse=True)
    # inverse[a] is the position of axis `a` within axes_by_stride, i.e.
    # the permutation that restores the original axis order.
    inverse = [0] * ndim
    for position, axis in enumerate(axes_by_stride):
        inverse[axis] = position

    sorted_strides = [strides[axis] for axis in axes_by_stride]
    sorted_shape = tuple(shape[axis] for axis in axes_by_stride)
    out = dpt.empty(
        sorted_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
    )
    if min(sorted_strides) < 0:
        # Reverse every axis along which X has a negative stride.
        flips = tuple(
            slice(None, None, -1) if s < 0 else slice(None, None, None)
            for s in sorted_strides
        )
        out = out[flips]
    return dpt.permute_dims(out, inverse)


def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
    """Allocate an uninitialized array of shape `res_shape` whose layout
    is derived from the strides of both `X1` and `X2`.

    Args:
        X1: first input array; must be a `dpt.usm_ndarray`.
        X2: second input array; must be a `dpt.usm_ndarray`.
        dt: data type of the returned array.
        res_shape: shape of the returned array.
        usm_type: USM allocation type passed to the allocation routine.
        dev: device passed to the allocation routine.

    Returns:
        The array produced by `_empty_like_orderK` or `dpt.empty`, or a
        permuted (and possibly axis-reversed) view of one produced by
        `dpt.empty`.

    Raises:
        TypeError: if `X1` or `X2` is not a `dpt.usm_ndarray`.
    """
    for candidate in (X1, X2):
        if not isinstance(candidate, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray, got {type(candidate)}")

    nd1 = X1.ndim
    nd2 = X2.ndim
    # When the higher-rank input already has the result shape, its layout
    # alone determines the result.
    if nd1 > nd2 and X1.shape == res_shape:
        return _empty_like_orderK(X1, dt, usm_type, dev)
    if nd1 < nd2 and X2.shape == res_shape:
        return _empty_like_orderK(X2, dt, usm_type, dev)

    flags1 = X1.flags
    flags2 = X2.flags
    if flags1["C"] or flags2["C"]:
        return dpt.empty(
            res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
        )
    if flags1["F"] and flags2["F"]:
        return dpt.empty(
            res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
        )

    # Pad both stride lists with trailing zeros up to the common rank.
    strides1 = list(X1.strides)
    strides2 = list(X2.strides)
    rank = max(nd1, nd2)
    strides1.extend([0] * (rank - len(strides1)))
    strides2.extend([0] * (rank - len(strides2)))

    def _stride_magnitudes(axis):
        return (builtins.abs(strides1[axis]), builtins.abs(strides2[axis]))

    # Axes ordered by (|stride1|, |stride2|), largest first.
    axes_by_stride = sorted(range(rank), key=_stride_magnitudes, reverse=True)
    # inverse[a] is the position of axis `a` within axes_by_stride, i.e.
    # the permutation that restores the original axis order.
    inverse = [0] * rank
    for position, axis in enumerate(axes_by_stride):
        inverse[axis] = position

    sorted_strides1 = [strides1[axis] for axis in axes_by_stride]
    sorted_strides2 = [strides2[axis] for axis in axes_by_stride]
    sorted_shape = tuple(res_shape[axis] for axis in axes_by_stride)
    out = dpt.empty(
        sorted_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
    )
    if min(sorted_strides1) < 0 and min(sorted_strides2) < 0:
        # Reverse an axis only when both inputs have a negative stride
        # along it; only the first nd1 sorted axes are examined.
        flips = tuple(
            slice(None, None, -1)
            if (s1 < 0 and s2 < 0)
            else slice(None, None, None)
            for s1, s2 in zip(sorted_strides1[:nd1], sorted_strides2[:nd1])
        )
        out = out[flips]
    return dpt.permute_dims(out, inverse)


def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
""" astype(array, new_dtype, order="K", casting="unsafe", \
copy=True)
Expand Down Expand Up @@ -432,26 +523,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
"Unrecognized value of the order keyword. "
"Recognized values are 'A', 'C', 'F', or 'K'"
)
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=target_dtype,
buffer=usm_ary.usm_type,
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
if order == "K" and (not c_contig and not f_contig):
original_strides = usm_ary.strides
ind = sorted(
range(usm_ary.ndim),
key=lambda i: abs(original_strides[i]),
reverse=True,
)
new_strides = tuple(R.strides[ind[i]] for i in ind)
if order == "K":
R = _empty_like_orderK(usm_ary, target_dtype)
else:
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=target_dtype,
buffer=R.usm_data,
strides=new_strides,
buffer=usm_ary.usm_type,
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
return R
Expand Down
3 changes: 1 addition & 2 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
from dpctl.utils import ExecutionPlacementError

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
_acceptance_fn_default,
_empty_like_orderK,
_empty_like_pair_orderK,
_find_buf_dtype,
_find_buf_dtype2,
_find_inplace_dtype,
Expand Down
94 changes: 0 additions & 94 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins

import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti

Expand Down Expand Up @@ -116,96 +114,6 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool):
return can_cast_v


def _empty_like_orderK(X, dt, usm_type=None, dev=None):
    """Allocate an uninitialized array like `X`, following order='K'.

    The result has the shape of `X` and data type `dt`. For an array `X`
    that was obtained by permutation of a contiguous array, the returned
    array will have the same shape and the same strides as `X`.

    Args:
        X: input array; must be a `dpt.usm_ndarray`.
        dt: data type of the returned array.
        usm_type: USM allocation type; `X.usm_type` is used when `None`.
        dev: device for the allocation; `X.device` is used when `None`.

    Returns:
        The array produced by `dpt.empty_like`, or a permuted (and
        possibly axis-reversed) view of one produced by `dpt.empty`.

    Raises:
        TypeError: if `X` is not a `dpt.usm_ndarray`.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray, got {type(X)}")
    usm_type = X.usm_type if usm_type is None else usm_type
    dev = X.device if dev is None else dev

    flags = X.flags
    # Inputs flagged C or F, and inputs with at most one element, need no
    # stride analysis.
    if flags["C"] or X.size <= 1:
        simple_order = "C"
    elif flags["F"]:
        simple_order = "F"
    else:
        simple_order = None
    if simple_order is not None:
        return dpt.empty_like(
            X, dtype=dt, usm_type=usm_type, device=dev, order=simple_order
        )

    ndim = X.ndim
    strides = list(X.strides)
    shape = X.shape

    def _stride_magnitude(axis):
        return builtins.abs(strides[axis])

    # Axes ordered from the largest absolute stride to the smallest.
    axes_by_stride = sorted(range(ndim), key=_stride_magnitude, reverse=True)
    # inverse[a] is the position of axis `a` within axes_by_stride, i.e.
    # the permutation that restores the original axis order.
    inverse = [0] * ndim
    for position, axis in enumerate(axes_by_stride):
        inverse[axis] = position

    sorted_strides = [strides[axis] for axis in axes_by_stride]
    sorted_shape = tuple(shape[axis] for axis in axes_by_stride)
    out = dpt.empty(
        sorted_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
    )
    if min(sorted_strides) < 0:
        # Reverse every axis along which X has a negative stride.
        flips = tuple(
            slice(None, None, -1) if s < 0 else slice(None, None, None)
            for s in sorted_strides
        )
        out = out[flips]
    return dpt.permute_dims(out, inverse)


def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
    """Allocate an uninitialized array of shape `res_shape` whose layout
    is derived from the strides of both `X1` and `X2`.

    Args:
        X1: first input array; must be a `dpt.usm_ndarray`.
        X2: second input array; must be a `dpt.usm_ndarray`.
        dt: data type of the returned array.
        res_shape: shape of the returned array.
        usm_type: USM allocation type passed to the allocation routine.
        dev: device passed to the allocation routine.

    Returns:
        The array produced by `_empty_like_orderK` or `dpt.empty`, or a
        permuted (and possibly axis-reversed) view of one produced by
        `dpt.empty`.

    Raises:
        TypeError: if `X1` or `X2` is not a `dpt.usm_ndarray`.
    """
    for candidate in (X1, X2):
        if not isinstance(candidate, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray, got {type(candidate)}")

    nd1 = X1.ndim
    nd2 = X2.ndim
    # When the higher-rank input already has the result shape, its layout
    # alone determines the result.
    if nd1 > nd2 and X1.shape == res_shape:
        return _empty_like_orderK(X1, dt, usm_type, dev)
    if nd1 < nd2 and X2.shape == res_shape:
        return _empty_like_orderK(X2, dt, usm_type, dev)

    flags1 = X1.flags
    flags2 = X2.flags
    if flags1["C"] or flags2["C"]:
        return dpt.empty(
            res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
        )
    if flags1["F"] and flags2["F"]:
        return dpt.empty(
            res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
        )

    # Pad both stride lists with trailing zeros up to the common rank.
    strides1 = list(X1.strides)
    strides2 = list(X2.strides)
    rank = max(nd1, nd2)
    strides1.extend([0] * (rank - len(strides1)))
    strides2.extend([0] * (rank - len(strides2)))

    def _stride_magnitudes(axis):
        return (builtins.abs(strides1[axis]), builtins.abs(strides2[axis]))

    # Axes ordered by (|stride1|, |stride2|), largest first.
    axes_by_stride = sorted(range(rank), key=_stride_magnitudes, reverse=True)
    # inverse[a] is the position of axis `a` within axes_by_stride, i.e.
    # the permutation that restores the original axis order.
    inverse = [0] * rank
    for position, axis in enumerate(axes_by_stride):
        inverse[axis] = position

    sorted_strides1 = [strides1[axis] for axis in axes_by_stride]
    sorted_strides2 = [strides2[axis] for axis in axes_by_stride]
    sorted_shape = tuple(res_shape[axis] for axis in axes_by_stride)
    out = dpt.empty(
        sorted_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
    )
    if min(sorted_strides1) < 0 and min(sorted_strides2) < 0:
        # Reverse an axis only when both inputs have a negative stride
        # along it; only the first nd1 sorted axes are examined.
        flips = tuple(
            slice(None, None, -1)
            if (s1 < 0 and s2 < 0)
            else slice(None, None, None)
            for s1, s2 in zip(sorted_strides1[:nd1], sorted_strides2[:nd1])
        )
        out = out[flips]
    return dpt.permute_dims(out, inverse)


def _to_device_supported_dtype(dt, dev):
has_fp16 = dev.has_aspect_fp16
has_fp64 = dev.has_aspect_fp64
Expand Down Expand Up @@ -339,8 +247,6 @@ def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
"_find_buf_dtype",
"_find_buf_dtype2",
"_find_inplace_dtype",
"_empty_like_orderK",
"_empty_like_pair_orderK",
"_to_device_supported_dtype",
"_acceptance_fn_default",
"_acceptance_fn_divide",
Expand Down

0 comments on commit fd54583

Please sign in to comment.