diff --git a/python_bindings/test/correctness/multi_method_module_test.py b/python_bindings/test/correctness/multi_method_module_test.py index f7596279bd23..ea4871a8ff43 100644 --- a/python_bindings/test/correctness/multi_method_module_test.py +++ b/python_bindings/test/correctness/multi_method_module_test.py @@ -31,7 +31,21 @@ def test_user_context(): multi_method_module.user_context(None, ord('q'), output) assert output == bytearray("qqqq", "ascii") +def test_aot_call_failure_throws_exception(): + buffer_input = np.zeros([2, 2], dtype=np.uint8) + func_input = np.zeros([2, 2], dtype=np.float32) # wrong type + float_arg = 3.5 + simple_output = np.zeros([2, 2], dtype=np.float32) + + try: + multi_method_module.simple(buffer_input, func_input, float_arg, simple_output) + except RuntimeError as e: + assert 'Halide Runtime Error: -3 (Input buffer func_input has type uint8 but type of the buffer passed in is float32)' in str(e), str(e) + else: + assert False, 'Did not see expected exception, saw: ' + str(e) if __name__ == "__main__": test_simple() test_user_context() + test_aot_call_failure_throws_exception() + diff --git a/src/PythonExtensionGen.cpp b/src/PythonExtensionGen.cpp index 92ee9d441c4b..3725a1bb4502 100644 --- a/src/PythonExtensionGen.cpp +++ b/src/PythonExtensionGen.cpp @@ -104,6 +104,7 @@ PythonExtensionGen::PythonExtensionGen(std::ostream &dest) } void PythonExtensionGen::compile(const Module &module) { + dest << "#include \n"; dest << "#include \n"; dest << "#include \"HalideRuntime.h\"\n\n"; @@ -269,6 +270,12 @@ namespace Halide::PythonExtensions { #undef X } // namespace Halide::PythonExtensions +#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS +namespace Halide::PythonRuntime { + thread_local std::string current_error; +} // namespace Halide::PythonRuntime +#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS + namespace { #define _HALIDE_STRINGIFY(x) #x @@ -295,17 +302,38 @@ PyModuleDef _moduledef = { nullptr, // free }; +#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS +void _module_halide_error(void *user_context, const char *msg) { + using Halide::PythonRuntime::current_error; + if (current_error.empty()) { + // fprintf(stderr, "Setting current_error=(%s)\n", msg); + current_error = msg; + } else { + // fprintf(stderr, "Warning: error (%s) ignored because current_error=(%s)\n", msg, current_error.c_str()); + } +} + +void _module_halide_print(void *user_context, const char *msg) { + PySys_FormatStdout("%s", msg); +} +#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS + } // namespace extern "C" { HALIDE_EXPORT_SYMBOL PyObject *_HALIDE_EXPAND_AND_CONCAT(PyInit_, HALIDE_PYTHON_EXTENSION_MODULE)() { - return PyModule_Create(&_moduledef); + PyObject *m = PyModule_Create(&_moduledef); + #ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS + halide_set_error_handler(_module_halide_error); + halide_set_custom_print(_module_halide_print); + #endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS + return m; } } // extern "C" -#endif //HALIDE_PYTHON_EXTENSION_OMIT_MODULE_DEFINITION +#endif // HALIDE_PYTHON_EXTENSION_OMIT_MODULE_DEFINITION )INLINE_CODE"; } @@ -323,9 +351,14 @@ void PythonExtensionGen::compile(const LoweredFunc &f) { dest << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_FUNCTION_DEFINITIONS\n"; dest << "\n"; + dest << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n"; + dest << "namespace Halide::PythonRuntime {\n"; + dest << "extern thread_local std::string current_error;\n"; + dest << "} // namespace Halide::PythonRuntime\n"; + dest << "#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n"; + dest << "\n"; dest << "namespace Halide::PythonExtensions {\n"; dest << "\n"; - dest << "namespace {\n"; dest << "\n"; dest << indent << "const char* const " << basename << "_kwlist[] = {\n"; @@ -369,12 +402,12 @@ void PythonExtensionGen::compile(const LoweredFunc &f) { dest << print_type(&arg).first; } dest << "\", (char**)" << basename << "_kwlist\n"; + indent.indent += 2; for (size_t i = 0; i < args.size(); i++) { - indent.indent += 2; dest << indent << ", &py_" << arg_names[i] << "\n"; - indent.indent -= 2; } - dest << ")) {\n"; + indent.indent -= 2; + dest << indent << ")) {\n"; indent.indent += 2; dest << indent << "PyErr_Format(PyExc_ValueError, \"Internal error\");\n"; dest << indent << "return nullptr;\n"; @@ -430,7 +463,13 @@ void PythonExtensionGen::compile(const LoweredFunc &f) { } dest << indent << "if (result != 0) {\n"; indent.indent += 2; + dest << indent << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n"; + dest << indent << "std::string take;\n"; + dest << indent << "std::swap(take, Halide::PythonRuntime::current_error);\n"; + dest << indent << "PyErr_Format(PyExc_RuntimeError, \"Halide Runtime Error: %d (%s)\", result, take.c_str());\n"; + dest << indent << "#else\n"; dest << indent << "PyErr_Format(PyExc_ValueError, \"Halide error %d\", result);\n"; + dest << indent << "#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n"; dest << indent << "return nullptr;\n"; indent.indent -= 2; dest << indent << "}\n";