Skip to content

Commit

Permalink
pybind11 type caster for sycl::half (#1655)
Browse files Browse the repository at this point in the history
* Implements type caster for sycl::half and removes unboxing_helper.hpp

`py::cast<sycl_half>` being available makes PythonObjectUnboxer redundant

* Apply changes per review to sycl::half caster

* Remove unnecessary check for py_err and unneeded PyErr_Clear call per PR feedback
  • Loading branch information
ndgrigorian authored Apr 27, 2024
1 parent aaf444e commit b12583a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 81 deletions.
47 changes: 47 additions & 0 deletions dpctl/apis/include/dpctl4pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,53 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
_("dpctl.program.SyclProgram"));
};

/* This type caster associates
* ``sycl::half`` C++ class with Python :class:`float` for the purposes
* of generation of Python bindings by pybind11.
*/
template <> struct type_caster<sycl::half>
{
public:
bool load(handle src, bool convert)
{
double py_value;

if (!src) {
return false;
}

PyObject *source = src.ptr();

if (convert || PyFloat_Check(source)) {
py_value = PyFloat_AsDouble(source);
}
else {
return false;
}

bool py_err = (py_value == double(-1)) && PyErr_Occurred();

if (py_err) {
PyErr_Clear();
if (convert && (PyNumber_Check(source) != 0)) {
auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
return load(tmp, false);
}
return false;
}
value = static_cast<sycl::half>(py_value);
return true;
}

static handle cast(sycl::half src, return_value_policy, handle)
{
return PyFloat_FromDouble(static_cast<double>(src));
}

PYBIND11_TYPE_CASTER(sycl::half, _("float"));
};

} // namespace detail
} // namespace pybind11

Expand Down
10 changes: 1 addition & 9 deletions dpctl/tensor/libtensor/source/full_ctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "utils/type_utils.hpp"

#include "full_ctor.hpp"
#include "unboxing_helper.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
Expand Down Expand Up @@ -79,14 +78,7 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
char *dst_p,
const std::vector<sycl::event> &depends)
{
dstTy fill_v;

PythonObjectUnboxer<dstTy> unboxer{};
try {
fill_v = unboxer(py_value);
} catch (const py::error_already_set &e) {
throw;
}
dstTy fill_v = py::cast<dstTy>(py_value);

using dpctl::tensor::kernels::constructors::full_contig_impl;

Expand Down
23 changes: 4 additions & 19 deletions dpctl/tensor/libtensor/source/linear_sequences.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "utils/type_utils.hpp"

#include "linear_sequences.hpp"
#include "unboxing_helper.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
Expand Down Expand Up @@ -86,16 +85,8 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
char *array_data,
const std::vector<sycl::event> &depends)
{
Ty start_v;
Ty step_v;

const auto &unboxer = PythonObjectUnboxer<Ty>{};
try {
start_v = unboxer(start);
step_v = unboxer(step);
} catch (const py::error_already_set &e) {
throw;
}
Ty start_v = py::cast<Ty>(start);
Ty step_v = py::cast<Ty>(step);

using dpctl::tensor::kernels::constructors::lin_space_step_impl;

Expand Down Expand Up @@ -143,14 +134,8 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q,
char *array_data,
const std::vector<sycl::event> &depends)
{
Ty start_v, end_v;
const auto &unboxer = PythonObjectUnboxer<Ty>{};
try {
start_v = unboxer(start);
end_v = unboxer(end);
} catch (const py::error_already_set &e) {
throw;
}
Ty start_v = py::cast<Ty>(start);
Ty end_v = py::cast<Ty>(end);

using dpctl::tensor::kernels::constructors::lin_space_affine_impl;

Expand Down
53 changes: 0 additions & 53 deletions dpctl/tensor/libtensor/source/unboxing_helper.hpp

This file was deleted.

0 comments on commit b12583a

Please sign in to comment.