Skip to content

Commit

Permalink
[FFI][BUGFIX] Grab GIL when check env signals
Browse files Browse the repository at this point in the history
This PR updates the CheckSignals function to grab GIL.
This is needed because we now explicitly release gil when calling
any C functions. GIL will need to be obtained otherwise we will
run into segfault when checking the signal.

The update now enables us to run ctrl + C in long running C functions.
  • Loading branch information
tqchen committed Sep 25, 2024
1 parent 30b7b1c commit 83f08a0
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 25 deletions.
16 changes: 11 additions & 5 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ cdef inline void* c_handle(object handle):
# python env API
cdef extern from "Python.h":
int PyErr_CheckSignals()
void* PyGILState_Ensure()
void PyGILState_Release(void*)
void Py_IncRef(void*)
void Py_DecRef(void*)

cdef extern from "tvm/runtime/c_backend_api.h":
int TVMBackendRegisterEnvCAPI(const char* name, void* ptr)
Expand All @@ -210,11 +214,13 @@ cdef _init_env_api():
# so backend can call tvm::runtime::EnvCheckSignals to check
# signal when executing a long running function.
#
# This feature is only enabled in cython for now due to problems of calling
# these functions in ctypes.
#
# When the functions are not registered, the signals will be handled
# only when the FFI function returns.
# Also registers the gil state release and ensure as PyErr_CheckSignals
# function is called with gil released and we need to regrab the gil
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyErr_CheckSignals"), <void*>PyErr_CheckSignals))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), <void*>PyGILState_Release))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), <void*>PyGILState_Release))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_IncRef"), <void*>Py_IncRef))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_DecRef"), <void*>Py_DecRef))

_init_env_api()
16 changes: 0 additions & 16 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -376,19 +376,3 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object

# Py_INCREF and Py_DECREF are C macros, not function objects.
# Therefore, providing a wrapper function.
cdef void _py_incref_wrapper(void* py_object):
Py_INCREF(<object>py_object)
cdef void _py_decref_wrapper(void* py_object):
Py_DECREF(<object>py_object)

def _init_pythonapi_inc_def_ref():
register_func = TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
register_func(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure)
register_func(c_str("PyGILState_Release"), <void*>PyGILState_Release)

_init_pythonapi_inc_def_ref()
12 changes: 8 additions & 4 deletions src/runtime/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ class EnvCAPIRegistry {
// implementation of tvm::runtime::EnvCheckSignals
void CheckSignals() {
// check python signal to see if there are exception raised
if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) {
// The error will let FFI know that the frontend environment
// already set an error.
throw EnvErrorAlreadySet("");
if (pyerr_check_signals != nullptr) {
// The C++ env comes without gil, so we need to grab gil here
WithGIL context(this);
if ((*pyerr_check_signals)() != 0) {
// The error will let FFI know that the frontend environment
// already set an error.
throw EnvErrorAlreadySet("");
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) {
std::this_thread::sleep_for(duration);
});

TVM_REGISTER_GLOBAL("testing.check_signals").set_body_typed([](double sleep_period) {
while (true) {
std::chrono::duration<int64_t, std::nano> duration(static_cast<int64_t>(sleep_period * 1e9));
std::this_thread::sleep_for(duration);
runtime::EnvCheckSignals();
}
});

TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant<String, IntImm> {
if (x % 2 == 0) {
return IntImm(DataType::Int(64), x / 2);
Expand Down

0 comments on commit 83f08a0

Please sign in to comment.