diff --git a/src/main/c/jni/org_jpy_PyLib.c b/src/main/c/jni/org_jpy_PyLib.c index 8f4e9e70..e0c1419d 100644 --- a/src/main/c/jni/org_jpy_PyLib.c +++ b/src/main/c/jni/org_jpy_PyLib.c @@ -2145,6 +2145,21 @@ JNIEXPORT jboolean JNICALL Java_org_jpy_PyLib_hasGil return result; } +/* + * Class: org_jpy_PyLib + * Method: ensureGil + * Signature: (Ljava/util/function/Supplier;)Ljava/lang/Object; + */ +JNIEXPORT jobject JNICALL Java_org_jpy_PyLib_ensureGil + (JNIEnv* jenv, jclass jLibClass, jobject supplier) +{ + jobject result; + JPy_BEGIN_GIL_STATE + result = (*jenv)->CallObjectMethod(jenv, supplier, JPy_Supplier_get_MID); + JPy_END_GIL_STATE + return result; +} + /* * Class: org_jpy_python_PyLib diff --git a/src/main/c/jni/org_jpy_PyLib.h b/src/main/c/jni/org_jpy_PyLib.h index 515b8029..d8498186 100644 --- a/src/main/c/jni/org_jpy_PyLib.h +++ b/src/main/c/jni/org_jpy_PyLib.h @@ -423,6 +423,14 @@ JNIEXPORT jboolean JNICALL Java_org_jpy_PyLib_hasAttribute JNIEXPORT jboolean JNICALL Java_org_jpy_PyLib_hasGil (JNIEnv *, jclass); +/* + * Class: org_jpy_PyLib + * Method: ensureGil + * Signature: (Ljava/util/function/Supplier;)Ljava/lang/Object; + */ +JNIEXPORT jobject JNICALL Java_org_jpy_PyLib_ensureGil + (JNIEnv *, jclass, jobject); + /* * Class: org_jpy_PyLib * Method: callAndReturnObject diff --git a/src/main/c/jpy_jmethod.c b/src/main/c/jpy_jmethod.c index f315758a..f77b5ebd 100644 --- a/src/main/c/jpy_jmethod.c +++ b/src/main/c/jpy_jmethod.c @@ -268,48 +268,80 @@ PyObject* JMethod_InvokeMethod(JNIEnv* jenv, JPy_JMethod* method, PyObject* pyAr JPy_DIAG_PRINT(JPy_DIAG_F_EXEC, "JMethod_InvokeMethod: calling static Java method %s#%s\n", declaringClass->javaName, JPy_AS_UTF8(method->name)); if (returnType == JPy_JVoid) { + Py_BEGIN_ALLOW_THREADS; (*jenv)->CallStaticVoidMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JVOID(); } else if (returnType == JPy_JBoolean) { - jboolean v = (*jenv)->CallStaticBooleanMethodA(jenv, classRef, method->mid, jArgs); + jboolean v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticBooleanMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JBOOLEAN(v); } else if (returnType == JPy_JChar) { - jchar v = (*jenv)->CallStaticCharMethodA(jenv, classRef, method->mid, jArgs); + jchar v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticCharMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JCHAR(v); } else if (returnType == JPy_JByte) { - jbyte v = (*jenv)->CallStaticByteMethodA(jenv, classRef, method->mid, jArgs); + jbyte v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticByteMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JBYTE(v); } else if (returnType == JPy_JShort) { - jshort v = (*jenv)->CallStaticShortMethodA(jenv, classRef, method->mid, jArgs); + jshort v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticShortMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JSHORT(v); } else if (returnType == JPy_JInt) { - jint v = (*jenv)->CallStaticIntMethodA(jenv, classRef, method->mid, jArgs); + jint v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticIntMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JINT(v); } else if (returnType == JPy_JLong) { - jlong v = (*jenv)->CallStaticLongMethodA(jenv, classRef, method->mid, jArgs); + jlong v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticLongMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JLONG(v); } else if (returnType == JPy_JFloat) { - jfloat v = (*jenv)->CallStaticFloatMethodA(jenv, classRef, method->mid, jArgs); + jfloat v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticFloatMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JFLOAT(v); } else if (returnType == JPy_JDouble) { - jdouble v = (*jenv)->CallStaticDoubleMethodA(jenv, classRef, method->mid, jArgs); + jdouble v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticDoubleMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JDOUBLE(v); } else if (returnType == JPy_JString) { - jstring v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs); + jstring v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FromJString(jenv, v); JPy_DELETE_LOCAL_REF(v); } else { - jobject v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs); + jobject v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JMethod_FromJObject(jenv, method, pyArgs, jArgs, 0, returnType, v); JPy_DELETE_LOCAL_REF(v); @@ -326,48 +358,80 @@ PyObject* JMethod_InvokeMethod(JNIEnv* jenv, JPy_JMethod* method, PyObject* pyAr objectRef = ((JPy_JObj*) self)->objectRef; if (returnType == JPy_JVoid) { + Py_BEGIN_ALLOW_THREADS; (*jenv)->CallVoidMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JVOID(); } else if (returnType == JPy_JBoolean) { - jboolean v = (*jenv)->CallBooleanMethodA(jenv, objectRef, method->mid, jArgs); + jboolean v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallBooleanMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JBOOLEAN(v); } else if (returnType == JPy_JChar) { - jchar v = (*jenv)->CallCharMethodA(jenv, objectRef, method->mid, jArgs); + jchar v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallCharMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JCHAR(v); } else if (returnType == JPy_JByte) { - jbyte v = (*jenv)->CallByteMethodA(jenv, objectRef, method->mid, jArgs); + jbyte v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallByteMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JBYTE(v); } else if (returnType == JPy_JShort) { - jshort v = (*jenv)->CallShortMethodA(jenv, objectRef, method->mid, jArgs); + jshort v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallShortMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JSHORT(v); } else if (returnType == JPy_JInt) { - jint v = (*jenv)->CallIntMethodA(jenv, objectRef, method->mid, jArgs); + jint v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallIntMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JINT(v); } else if (returnType == JPy_JLong) { - jlong v = (*jenv)->CallLongMethodA(jenv, objectRef, method->mid, jArgs); + jlong v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallLongMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JLONG(v); } else if (returnType == JPy_JFloat) { - jfloat v = (*jenv)->CallFloatMethodA(jenv, objectRef, method->mid, jArgs); + jfloat v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallFloatMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JFLOAT(v); } else if (returnType == JPy_JDouble) { - jdouble v = (*jenv)->CallDoubleMethodA(jenv, objectRef, method->mid, jArgs); + jdouble v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallDoubleMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FROM_JDOUBLE(v); } else if (returnType == JPy_JString) { - jstring v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs); + jstring v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JPy_FromJString(jenv, v); JPy_DELETE_LOCAL_REF(v); } else { - jobject v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs); + jobject v; + Py_BEGIN_ALLOW_THREADS; + v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs); + Py_END_ALLOW_THREADS; JPy_ON_JAVA_EXCEPTION_GOTO(error); returnValue = JMethod_FromJObject(jenv, method, pyArgs, jArgs, 1, returnType, v); JPy_DELETE_LOCAL_REF(v); diff --git a/src/main/c/jpy_jtype.c b/src/main/c/jpy_jtype.c index 8848565c..48e2a0bd 100644 --- a/src/main/c/jpy_jtype.c +++ b/src/main/c/jpy_jtype.c @@ -377,7 +377,9 @@ int JType_PythonToJavaConversionError(JPy_JType* type, PyObject* pyArg) int JType_CreateJavaObject(JNIEnv* jenv, JPy_JType* type, PyObject* pyArg, jclass classRef, jmethodID initMID, jvalue value, jobject* objectRef) { + Py_BEGIN_ALLOW_THREADS; *objectRef = (*jenv)->NewObjectA(jenv, classRef, initMID, &value); + Py_END_ALLOW_THREADS; if (*objectRef == NULL) { PyErr_NoMemory(); return -1; @@ -388,7 +390,9 @@ int JType_CreateJavaObject(JNIEnv* jenv, JPy_JType* type, PyObject* pyArg, jclas int JType_CreateJavaObject_2(JNIEnv* jenv, JPy_JType* type, PyObject* pyArg, jclass classRef, jmethodID initMID, jvalue value1, jvalue value2, jobject* objectRef) { + Py_BEGIN_ALLOW_THREADS; *objectRef = (*jenv)->NewObject(jenv, classRef, initMID, value1, value2); + Py_END_ALLOW_THREADS; if (*objectRef == NULL) { PyErr_NoMemory(); return -1; diff --git a/src/main/c/jpy_module.c b/src/main/c/jpy_module.c index 24a714de..0626f26c 100644 --- a/src/main/c/jpy_module.c +++ b/src/main/c/jpy_module.c @@ -244,6 +244,10 @@ jmethodID JPy_Throwable_getCause_MID = NULL; // stack trace element jclass JPy_StackTraceElement_JClass = NULL; +// java.util.function.Supplier +jclass JPy_Supplier_JClass = NULL; +jmethodID JPy_Supplier_get_MID = NULL; + // }}} @@ -954,6 +958,9 @@ int JPy_InitGlobalVars(JNIEnv* jenv) DEFINE_METHOD(JPy_Throwable_getCause_MID, JPy_Throwable_JClass, "getCause", "()Ljava/lang/Throwable;"); DEFINE_METHOD(JPy_Throwable_getStackTrace_MID, JPy_Throwable_JClass, "getStackTrace", "()[Ljava/lang/StackTraceElement;"); + DEFINE_CLASS(JPy_Supplier_JClass, "java/util/function/Supplier"); + DEFINE_METHOD(JPy_Supplier_get_MID, JPy_Supplier_JClass, "get", "()Ljava/lang/Object;") + // JType_AddClassAttribute is actually called from within JType_GetType(), but not for // JPy_JObject and JPy_JClass for an obvious reason. So we do it now: JType_AddClassAttribute(jenv, JPy_JObject); diff --git a/src/main/c/jpy_module.h b/src/main/c/jpy_module.h index ce96959e..fbcaee28 100644 --- a/src/main/c/jpy_module.h +++ b/src/main/c/jpy_module.h @@ -256,6 +256,9 @@ extern jmethodID JPy_PyObject_Init_MID; extern jclass JPy_PyDictWrapper_JClass; extern jmethodID JPy_PyDictWrapper_GetPointer_MID; +extern jclass JPy_Supplier_JClass; +extern jmethodID JPy_Supplier_get_MID; + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/src/main/java/org/jpy/PyLib.java b/src/main/java/org/jpy/PyLib.java index ce1e4dd7..2f07a064 100644 --- a/src/main/java/org/jpy/PyLib.java +++ b/src/main/java/org/jpy/PyLib.java @@ -22,6 +22,7 @@ import java.io.File; import java.io.FileNotFoundException; import java.util.ArrayList; +import java.util.function.Supplier; import static org.jpy.PyLibConfig.*; @@ -431,6 +432,8 @@ public static void stopPython() { public static native boolean hasGil(); + public static native T ensureGil(Supplier runnable); + /** * Calls a Python callable and returns the resulting Python object. *

diff --git a/src/main/java/org/jpy/PyObjectReferences.java b/src/main/java/org/jpy/PyObjectReferences.java index fe35d1f1..1196cf14 100644 --- a/src/main/java/org/jpy/PyObjectReferences.java +++ b/src/main/java/org/jpy/PyObjectReferences.java @@ -81,33 +81,30 @@ public int cleanupOnlyUseFromGIL() { } private int cleanupOnlyUseFromGIL(long[] buffer) { - if (!PyLib.hasGil()) { - throw new IllegalStateException( - "We should only be calling PyObjectReferences.cleanupOnlyUseFromGIL if we have the GIL!"); - } - - int index = 0; - while (index < buffer.length) { - final Reference reference = referenceQueue.poll(); - if (reference == null) { - break; + return PyLib.ensureGil(() -> { + int index = 0; + while (index < buffer.length) { + final Reference reference = referenceQueue.poll(); + if (reference == null) { + break; + } + index = appendIfNotClosed(buffer, index, reference); + } + if (index == 0) { + return 0; } - index = appendIfNotClosed(buffer, index, reference); - } - if (index == 0) { - return 0; - } - // We really really really want to make sure we *already* have the GIL lock at this point in - // time. Otherwise, we block here until the GIL is available for us, and stall all cleanup - // related to our PyObjects. + // We really really really want to make sure we *already* have the GIL lock at this point in + // time. Otherwise, we block here until the GIL is available for us, and stall all cleanup + // related to our PyObjects. - if (index == 1) { - PyLib.decRef(buffer[0]); - return 1; - } - PyLib.decRefs(buffer, index); - return index; + if (index == 1) { + PyLib.decRef(buffer[0]); + return 1; + } + PyLib.decRefs(buffer, index); + return index; + }); } private int appendIfNotClosed(long[] buffer, int index, Reference reference) { diff --git a/src/test/java/org/jpy/PyLibTest.java b/src/test/java/org/jpy/PyLibTest.java index b50171be..bac187a8 100644 --- a/src/test/java/org/jpy/PyLibTest.java +++ b/src/test/java/org/jpy/PyLibTest.java @@ -283,4 +283,26 @@ public void decRefs() { PyLib.decRefs(new long[] { pyObject1, pyObject2, 0, 0 }, 2); } + + @Test + public void testEnsureGIL() { + assertFalse(PyLib.hasGil()); + boolean[] lambdaSuccessfullyRan = {false}; + Integer intResult = PyLib.ensureGil(() -> { + assertTrue(PyLib.hasGil()); + lambdaSuccessfullyRan[0] = true; + return 123; + }); + assertEquals((Integer) 123, intResult); + assertTrue(lambdaSuccessfullyRan[0]); + + try { + Object result = PyLib.ensureGil(() -> { + throw new IllegalStateException("Error from inside GIL block"); + }); + fail("Exception expected"); + } catch (IllegalStateException expectedException) { + assertEquals("Error from inside GIL block", expectedException.getMessage()); + }//let anything else rethrow as a failure + } }