Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate changes in pr 1763 from changes in pr 1760 #1772

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
elif dt_kind == "i":
return dpt.dtype(ti.default_device_int_type(sycl_dev))
elif dt_kind == "u":
return dpt.dtype(ti.default_device_int_type(sycl_dev).upper())
return dpt.dtype(ti.default_device_uint_type(sycl_dev))
elif dt_kind == "f":
return dpt.dtype(ti.default_device_fp_type(sycl_dev))
elif dt_kind == "c":
Expand Down Expand Up @@ -790,7 +790,7 @@ def _default_accumulation_dtype(inp_dt, q):
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_dt = dpt.dtype(ti.default_device_uint_type(q))
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
Expand Down
49 changes: 46 additions & 3 deletions dpctl/tensor/libtensor/source/device_support_queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,48 @@ std::string _default_device_fp_type(const sycl::device &d)
}
}

int get_numpy_major_version()
{
namespace py = pybind11;

py::module_ numpy = py::module_::import("numpy");
py::str version_string = numpy.attr("__version__");
py::module_ numpy_lib = py::module_::import("numpy.lib");

py::object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
int major_version = numpy_version.attr("major").cast<int>();

return major_version;
}

std::string _default_device_int_type(const sycl::device &)
{
return "l"; // code for numpy.dtype('long') to be consistent
// with NumPy's default integer type across
// platforms.
const int np_ver = get_numpy_major_version();

if (np_ver >= 2) {
return "i8";
}
else {
// code for numpy.dtype('long') to be consistent
// with NumPy's default integer type across
// platforms.
return "l";
}
}

std::string _default_device_uint_type(const sycl::device &)
{
const int np_ver = get_numpy_major_version();

if (np_ver >= 2) {
return "u8";
}
else {
// code for numpy.dtype('long') to be consistent
// with NumPy's default integer type across
// platforms.
return "L";
}
}

std::string _default_device_complex_type(const sycl::device &d)
Expand Down Expand Up @@ -108,6 +145,12 @@ std::string default_device_int_type(const py::object &arg)
return _default_device_int_type(d);
}

std::string default_device_uint_type(const py::object &arg)
{
const sycl::device &d = _extract_device(arg);
return _default_device_uint_type(d);
}

std::string default_device_bool_type(const py::object &arg)
{
const sycl::device &d = _extract_device(arg);
Expand Down
1 change: 1 addition & 0 deletions dpctl/tensor/libtensor/source/device_support_queries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace py_internal

extern std::string default_device_fp_type(const py::object &);
extern std::string default_device_int_type(const py::object &);
extern std::string default_device_uint_type(const py::object &);
extern std::string default_device_bool_type(const py::object &);
extern std::string default_device_complex_type(const py::object &);
extern std::string default_device_index_type(const py::object &);
Expand Down
8 changes: 7 additions & 1 deletion dpctl/tensor/libtensor/source/tensor_ctors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,13 @@ PYBIND11_MODULE(_tensor_impl, m)

m.def("default_device_int_type",
dpctl::tensor::py_internal::default_device_int_type,
"Gives default integer type supported by device.", py::arg("dev"));
"Gives default signed integer type supported by device.",
py::arg("dev"));

m.def("default_device_uint_type",
dpctl::tensor::py_internal::default_device_uint_type,
"Gives default unsigned integer type supported by device.",
py::arg("dev"));

m.def("default_device_bool_type",
dpctl::tensor::py_internal::default_device_bool_type,
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,9 +1010,9 @@ def test_pyx_capi_check_constants():
assert uint_typenum == dpt.dtype(np.uintc).num

long_typenum = _pyx_capi_int(X, "UAR_LONG")
assert long_typenum == dpt.dtype(np.int_).num
assert long_typenum == dpt.dtype("l").num
ulong_typenum = _pyx_capi_int(X, "UAR_ULONG")
assert ulong_typenum == dpt.dtype(np.uint).num
assert ulong_typenum == dpt.dtype("L").num

longlong_typenum = _pyx_capi_int(X, "UAR_LONGLONG")
assert longlong_typenum == dpt.dtype(np.longlong).num
Expand Down
Loading