Skip to content

Commit

Permalink
Make sure ext_oneapi_get_default_context doesn't broke runtime on w…
Browse files Browse the repository at this point in the history
…indows (#2742)

Part of #2478 (to reduce diff)

These are quite stable changes, we can merge it without CI on Windows.
@gshimansky if you don't mind.

---------

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Nov 19, 2024
1 parent cc1d4c5 commit 29d27d7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
23 changes: 21 additions & 2 deletions third_party/intel/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,26 @@ struct BuildFlags {
}
};

sycl::context get_default_context(const sycl::device &sycl_device) {
const auto &platform = sycl_device.get_platform();
#ifdef WIN32
sycl::context ctx;
try {
ctx = platform.ext_oneapi_get_default_context();
} catch (const std::runtime_error &ex) {
// This exception is thrown on Windows because
// ext_oneapi_get_default_context is not implemented. But it can be safely
// ignored it seems.
#if _DEBUG
std::cout << "ERROR: " << ex.what() << std::endl;
#endif
}
return ctx;
#else
return platform.ext_oneapi_get_default_context();
#endif
}

static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name, *build_flags_ptr;
int shared;
Expand Down Expand Up @@ -194,8 +214,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
const size_t binary_size = PyBytes_Size(py_bytes);

uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
const auto ctx =
sycl_device.get_platform().ext_oneapi_get_default_context();
const auto &ctx = get_default_context(sycl_device);
const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
Expand Down
23 changes: 22 additions & 1 deletion utils/SPIRVRunner/SPIRVRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
return std::get<0>(tuple);
}

sycl::context get_default_context(const sycl::device &sycl_device) {
const auto &platform = sycl_device.get_platform();
#ifdef WIN32
sycl::context ctx;
try {
ctx = platform.ext_oneapi_get_default_context();
} catch (const std::runtime_error &ex) {
// This exception is thrown on Windows because
// ext_oneapi_get_default_context is not implemented. But it can be safely
// ignored it seems.
#if _DEBUG
std::cout << "ERROR: " << ex.what() << std::endl;
#endif
}
return ctx;
#else
return platform.ext_oneapi_get_default_context();
#endif
}

/** SYCL Functions **/
std::tuple<sycl::kernel_bundle<sycl::bundle_state::executable>, sycl::kernel,
int32_t, int32_t>
Expand All @@ -138,7 +158,8 @@ loadBinary(const std::string &kernel_name, const std::string &build_flags,
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[deviceId];
const sycl::device sycl_device = sycl_l0_device_pair.first;

const auto ctx = sycl_device.get_platform().ext_oneapi_get_default_context();
const auto &ctx = get_default_context(sycl_device);

const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
Expand Down

0 comments on commit 29d27d7

Please sign in to comment.