Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error-handling in Python Extensions #6986

Merged
merged 4 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python_bindings/test/correctness/multi_method_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,21 @@ def test_user_context():
multi_method_module.user_context(None, ord('q'), output)
assert output == bytearray("qqqq", "ascii")

def test_aot_call_failure_throws_exception():
buffer_input = np.zeros([2, 2], dtype=np.uint8)
func_input = np.zeros([2, 2], dtype=np.float32) # wrong type
float_arg = 3.5
simple_output = np.zeros([2, 2], dtype=np.float32)

try:
multi_method_module.simple(buffer_input, func_input, float_arg, simple_output)
except RuntimeError as e:
assert 'Halide Runtime Error: -3 (Input buffer func_input has type uint8 but type of the buffer passed in is float32)' in str(e), str(e)
else:
assert False, 'Did not see expected exception, saw: ' + str(e)

if __name__ == "__main__":
test_simple()
test_user_context()
test_aot_call_failure_throws_exception()

51 changes: 45 additions & 6 deletions src/PythonExtensionGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ PythonExtensionGen::PythonExtensionGen(std::ostream &dest)
}

void PythonExtensionGen::compile(const Module &module) {
dest << "#include <string>\n";
dest << "#include <Python.h>\n";
dest << "#include \"HalideRuntime.h\"\n\n";

Expand Down Expand Up @@ -269,6 +270,12 @@ namespace Halide::PythonExtensions {
#undef X
} // namespace Halide::PythonExtensions

#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS
namespace Halide::PythonRuntime {
thread_local std::string current_error;
} // namespace Halide::PythonRuntime
#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS

namespace {

#define _HALIDE_STRINGIFY(x) #x
Expand All @@ -295,17 +302,38 @@ PyModuleDef _moduledef = {
nullptr, // free
};

#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS
void _module_halide_error(void *user_context, const char *msg) {
using Halide::PythonRuntime::current_error;
if (current_error.empty()) {
// fprintf(stderr, "Setting current_error=(%s)\n", msg);
current_error = msg;
} else {
// fprintf(stderr, "Warning: error (%s) ignored because current_error=(%s)\n", msg, current_error.c_str());
}
}

void _module_halide_print(void *user_context, const char *msg) {
PySys_FormatStdout("%s", msg);
}
#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS

} // namespace

extern "C" {

HALIDE_EXPORT_SYMBOL PyObject *_HALIDE_EXPAND_AND_CONCAT(PyInit_, HALIDE_PYTHON_EXTENSION_MODULE)() {
return PyModule_Create(&_moduledef);
PyObject *m = PyModule_Create(&_moduledef);
#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS
halide_set_error_handler(_module_halide_error);
halide_set_custom_print(_module_halide_print);
#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS
return m;
}

} // extern "C"

#endif //HALIDE_PYTHON_EXTENSION_OMIT_MODULE_DEFINITION
#endif // HALIDE_PYTHON_EXTENSION_OMIT_MODULE_DEFINITION
)INLINE_CODE";
}

Expand All @@ -323,9 +351,14 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {

dest << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_FUNCTION_DEFINITIONS\n";
dest << "\n";
dest << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n";
dest << "namespace Halide::PythonRuntime {\n";
dest << "extern thread_local std::string current_error;\n";
dest << "} // namespace Halide::PythonRuntime\n";
dest << "#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n";
dest << "\n";
dest << "namespace Halide::PythonExtensions {\n";
dest << "\n";

dest << "namespace {\n";
dest << "\n";
dest << indent << "const char* const " << basename << "_kwlist[] = {\n";
Expand Down Expand Up @@ -369,12 +402,12 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
dest << print_type(&arg).first;
}
dest << "\", (char**)" << basename << "_kwlist\n";
indent.indent += 2;
for (size_t i = 0; i < args.size(); i++) {
indent.indent += 2;
dest << indent << ", &py_" << arg_names[i] << "\n";
indent.indent -= 2;
}
dest << ")) {\n";
indent.indent -= 2;
dest << indent << ")) {\n";
indent.indent += 2;
dest << indent << "PyErr_Format(PyExc_ValueError, \"Internal error\");\n";
dest << indent << "return nullptr;\n";
Expand Down Expand Up @@ -430,7 +463,13 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
}
dest << indent << "if (result != 0) {\n";
indent.indent += 2;
dest << indent << "#ifndef HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n";
dest << indent << "std::string take;\n";
dest << indent << "std::swap(take, Halide::PythonRuntime::current_error);\n";
dest << indent << "PyErr_Format(PyExc_RuntimeError, \"Halide Runtime Error: %d (%s)\", result, take.c_str());\n";
dest << indent << "#else\n";
dest << indent << "PyErr_Format(PyExc_ValueError, \"Halide error %d\", result);\n";
dest << indent << "#endif // HALIDE_PYTHON_EXTENSION_OMIT_ERROR_AND_PRINT_HANDLERS\n";
dest << indent << "return nullptr;\n";
indent.indent -= 2;
dest << indent << "}\n";
Expand Down