Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error-handling in Python Extensions #6986

Merged
merged 4 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python_bindings/test/correctness/multi_method_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,21 @@ def test_user_context():
multi_method_module.user_context(None, ord('q'), output)
assert output == bytearray("qqqq", "ascii")

def test_aot_call_failure_throws_exception():
buffer_input = np.zeros([2, 2], dtype=np.uint8)
func_input = np.zeros([2, 2], dtype=np.float32) # wrong type
float_arg = 3.5
simple_output = np.zeros([2, 2], dtype=np.float32)

try:
multi_method_module.simple(buffer_input, func_input, float_arg, simple_output)
except RuntimeError as e:
assert 'Halide Runtime Error: -3 (Input buffer func_input has type uint8 but type of the buffer passed in is float32)' in str(e), str(e)
else:
assert False, 'Did not see expected exception, saw: ' + str(e)

if __name__ == "__main__":
test_simple()
test_user_context()
test_aot_call_failure_throws_exception()

46 changes: 40 additions & 6 deletions src/PythonExtensionGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ PythonExtensionGen::PythonExtensionGen(std::ostream &dest)
}

void PythonExtensionGen::compile(const Module &module) {
dest << "#include <string>\n";
dest << "#include <Python.h>\n";
dest << "#include \"HalideRuntime.h\"\n\n";

Expand Down Expand Up @@ -269,6 +270,10 @@ namespace Halide::PythonExtensions {
#undef X
} // namespace Halide::PythonExtensions

namespace Halide::PythonRuntime {
thread_local std::string current_error;
} // namespace Halide::PythonRuntime

namespace {

#define _HALIDE_STRINGIFY(x) #x
Expand All @@ -295,12 +300,36 @@ PyModuleDef _moduledef = {
nullptr, // free
};

void _module_halide_error(void *user_context, const char *msg) {
using Halide::PythonRuntime::current_error;
if (current_error.empty()) {
// fprintf(stderr, "Setting current_error=(%s)\n", msg);
current_error = msg;
} else {
// fprintf(stderr, "Warning: error (%s) ignored because current_error=(%s)\n", msg, current_error.c_str());
}
}

void _module_halide_print(void *user_context, const char *msg) {
PySys_FormatStdout("%s", msg);
}

} // namespace

extern "C" {

#ifdef HALIDE_PYTHON_EXTENSION_INCLUDE_RUNTIME_SUBMODULE
extern PyObject *_halide_runtime_submodule_impl(PyObject *module);
#endif

HALIDE_EXPORT_SYMBOL PyObject *_HALIDE_EXPAND_AND_CONCAT(PyInit_, HALIDE_PYTHON_EXTENSION_MODULE)() {
return PyModule_Create(&_moduledef);
PyObject *m = PyModule_Create(&_moduledef);
#ifdef HALIDE_PYTHON_EXTENSION_INCLUDE_RUNTIME_SUBMODULE
(void)_halide_runtime_submodule_impl(m);
#endif
halide_set_error_handler(_module_halide_error);
halide_set_custom_print(_module_halide_print);
return m;
}

} // extern "C"
Expand All @@ -323,9 +352,12 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {

dest << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_FUNCTION_DEFINITIONS\n";
dest << "\n";
dest << "namespace Halide::PythonRuntime {\n";
dest << "extern thread_local std::string current_error;\n";
dest << "} // namespace Halide::PythonRuntime\n";
dest << "\n";
dest << "namespace Halide::PythonExtensions {\n";
dest << "\n";

dest << "namespace {\n";
dest << "\n";
dest << indent << "const char* const " << basename << "_kwlist[] = {\n";
Expand Down Expand Up @@ -369,12 +401,12 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
dest << print_type(&arg).first;
}
dest << "\", (char**)" << basename << "_kwlist\n";
indent.indent += 2;
for (size_t i = 0; i < args.size(); i++) {
indent.indent += 2;
dest << indent << ", &py_" << arg_names[i] << "\n";
indent.indent -= 2;
}
dest << ")) {\n";
indent.indent -= 2;
dest << indent << ")) {\n";
indent.indent += 2;
dest << indent << "PyErr_Format(PyExc_ValueError, \"Internal error\");\n";
dest << indent << "return nullptr;\n";
Expand Down Expand Up @@ -430,7 +462,9 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
}
dest << indent << "if (result != 0) {\n";
indent.indent += 2;
dest << indent << "PyErr_Format(PyExc_ValueError, \"Halide error %d\", result);\n";
dest << indent << "std::string take;\n";
dest << indent << "std::swap(take, Halide::PythonRuntime::current_error);\n";
dest << indent << "PyErr_Format(PyExc_RuntimeError, \"Halide Runtime Error: %d (%s)\", result, take.c_str());\n";
dest << indent << "return nullptr;\n";
indent.indent -= 2;
dest << indent << "}\n";
Expand Down