Skip to content

Commit

Permalink
Add a HalideError base class to Python bindings (#6750)
Browse files Browse the repository at this point in the history
* Add a `HalideError` base class to Python bindings

Per suggestion from @alexreinking, this remaps all exceptions thrown by the Halide Python bindings to be `halide.HalideError` (or a subclass thereof), rather than plain old `RuntimeError`.

* Remove scalpel left in patient

* Don't use a subclass for PyStub error handling
  • Loading branch information
steven-johnson authored May 6, 2022
1 parent 6fbf203 commit 47d8103
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 39 deletions.
20 changes: 10 additions & 10 deletions python_bindings/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_compiletime_error():
buf = hl.Buffer(hl.UInt(8), [2, 2])
try:
f.realize(buf)
except RuntimeError as e:
except hl.HalideError as e:
assert 'Output buffer f has type uint16 but type of the buffer passed in is uint8' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -25,7 +25,7 @@ def test_runtime_error():
buf = hl.Buffer(hl.UInt(8), [10])
try:
f.realize(buf)
except RuntimeError as e:
except hl.HalideError as e:
assert 'do not cover required region' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_basics2():

try:
val1 = clamped[x * s_sigma - s_sigma/2, y * s_sigma - s_sigma/2]
except RuntimeError as e:
except hl.HalideError as e:
assert 'Implicit cast from float32 to int' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand Down Expand Up @@ -317,21 +317,21 @@ def test_typed_funcs():
assert not f.defined()
try:
assert f.output_type() == Int(32)
except RuntimeError as e:
except hl.HalideError as e:
assert 'it is undefined' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
assert f.outputs() == 0
except RuntimeError as e:
except hl.HalideError as e:
assert 'it is undefined' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
assert f.dimensions() == 0
except RuntimeError as e:
except hl.HalideError as e:
assert 'it is undefined' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -348,7 +348,7 @@ def test_typed_funcs():
assert not f.defined()
try:
assert f.output_type() == hl.Int(32)
except RuntimeError as e:
except hl.HalideError as e:
assert 'it returns a Tuple' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -361,7 +361,7 @@ def test_typed_funcs():
try:
f[x, y] = hl.i32(0);
f.realize([10, 10])
except RuntimeError as e:
except hl.HalideError as e:
assert 'is constrained to have exactly 1 dimensions, but is defined with 2 dimensions' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -370,7 +370,7 @@ def test_typed_funcs():
try:
f[x, y] = hl.i16(0);
f.realize([10, 10])
except RuntimeError as e:
except hl.HalideError as e:
assert 'is constrained to only hold values of type int32 but is defined with values of type int16' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -379,7 +379,7 @@ def test_typed_funcs():
try:
f[x, y] = (hl.i16(0), hl.f64(0))
f.realize([10, 10])
except RuntimeError as e:
except hl.HalideError as e:
assert 'is constrained to only hold values of type (int32, float32) but is defined with values of type (int16, float64)' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand Down
12 changes: 6 additions & 6 deletions python_bindings/correctness/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def test_extern():

try:
sort_func.compile_jit()
except RuntimeError:
pass
except hl.HalideError:
assert 'cannot be converted to a bool' in str(e)
else:
raise Exception("compile_jit should have raised a 'Symbol not found' RuntimeError")
assert False, 'Did not see expected exception!'


import ctypes
Expand All @@ -44,10 +44,10 @@ def test_extern():

try:
sort_func.compile_jit()
except RuntimeError:
print("ctypes CDLL did not work out")
except hl.HalideError:
assert 'cannot be converted to a bool' in str(e)
else:
print("ctypes CDLL worked !")
assert False, 'Did not see expected exception!'

lib_path = "the_sort_function.so"
#lib_path = "/home/rodrigob/code/references/" \
Expand Down
22 changes: 11 additions & 11 deletions python_bindings/correctness/pystub.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,47 +61,47 @@ def test_simple(gen):
try:
# Inputs w/ mixed by-position and by-name
f = gen(target, b_in, f_in, float_arg=3.5)
except RuntimeError as e:
except hl.HalideError as e:
assert 'Cannot use both positional and keyword arguments for inputs.' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# too many positional args
f = gen(target, b_in, f_in, 3.5, 4)
except RuntimeError as e:
except hl.HalideError as e:
assert 'Expected exactly 3 positional args for inputs, but saw 4.' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# too few positional args
f = gen(target, b_in, f_in)
except RuntimeError as e:
except hl.HalideError as e:
assert 'Expected exactly 3 positional args for inputs, but saw 2.' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Inputs that can't be converted to what the receiver needs (positional)
f = gen(target, hl.f32(3.141592), "happy", k)
except RuntimeError as e:
assert 'Unable to cast Python instance' in str(e)
except hl.HalideError as e:
assert 'Input func_input requires a Param (or scalar literal) argument' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Inputs that can't be converted to what the receiver needs (named)
f = gen(target, b_in, f_in, float_arg="bogus")
except RuntimeError as e:
assert 'Unable to cast Python instance' in str(e)
except hl.HalideError as e:
assert 'Input float_arg requires a Param (or scalar literal) argument' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Input specified by both pos and kwarg
f = gen(target, b_in, f_in, 3.5, float_arg=4.5)
except RuntimeError as e:
except hl.HalideError as e:
assert "Cannot use both positional and keyword arguments for inputs." in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -117,23 +117,23 @@ def test_simple(gen):
try:
# Bad gp name
f = gen(target, b_in, f_in, 3.5, generator_params={"foo": 0})
except RuntimeError as e:
except hl.HalideError as e:
assert "has no GeneratorParam named: foo" in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Bad input name
f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, funk_input=f_in)
except RuntimeError as e:
except hl.HalideError as e:
assert "Unknown input 'funk_input' specified via keyword argument." in str(e)
else:
assert False, 'Did not see expected exception!'

try:
# Bad gp name
f = gen(target, buffer_input=b_in, float_arg=3.5, generator_params=gp, func_input=f_in, nonexistent_generator_param="wat")
except RuntimeError as e:
except hl.HalideError as e:
assert "Unknown input 'nonexistent_generator_param' specified via keyword argument." in str(e)
else:
assert False, 'Did not see expected exception!'
Expand Down
4 changes: 2 additions & 2 deletions python_bindings/correctness/tuple_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_tuple_select():
f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y),
x + y < 100, (x-1, y-2),
(x-100, y-200))
except RuntimeError as e:
except hl.HalideError as e:
assert 'tuple_select() may not mix Expr and Tuple for the condition elements.' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand All @@ -73,7 +73,7 @@ def test_tuple_select():
try:
f = hl.Func('f')
f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y, 0), (1, 2, 3, 4))
except RuntimeError as e:
except hl.HalideError as e:
assert 'tuple_select() requires all Tuples to have identical sizes' in str(e)
else:
assert False, 'Did not see expected exception!'
Expand Down
11 changes: 11 additions & 0 deletions python_bindings/src/PyError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ void define_error(py::module &m) {
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

static py::exception<Error> halide_error(m, "HalideError");
py::register_exception_translator([](std::exception_ptr p) { // NOLINT
try {
if (p) {
std::rethrow_exception(p);
}
} catch (const Error &e) {
halide_error(e.what());
}
});
}

} // namespace PythonBindings
Expand Down
76 changes: 66 additions & 10 deletions python_bindings/stub/PyStubImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ void halide_python_error(JITUserContext *, const char *msg) {
}

void halide_python_print(JITUserContext *, const char *msg) {
py::gil_scoped_acquire acquire;
py::print(msg, py::arg("end") = "");
}

class HalidePythonCompileTimeErrorReporter : public CompileTimeErrorReporter {
public:
void warning(const char *msg) override {
py::gil_scoped_acquire acquire;
py::print(msg, py::arg("end") = "");
}

Expand All @@ -63,6 +65,21 @@ void install_error_handlers(py::module &m) {
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

static py::object halide_error = py::module_::import("halide").attr("HalideError");
if (halide_error.is(py::none())) {
throw std::runtime_error("Could not find halide.HalideError");
}

py::register_exception_translator([](std::exception_ptr p) { // NOLINT
try {
if (p) {
std::rethrow_exception(p);
}
} catch (const Error &e) {
PyErr_SetString(halide_error.ptr(), e.what());
}
});
}

// Anything that defines __getitem__ looks sequencelike to pybind,
Expand All @@ -71,33 +88,72 @@ bool is_real_sequence(const py::object &o) {
return py::isinstance<py::sequence>(o) && py::hasattr(o, "__len__");
}

StubInput to_stub_input(const py::object &o) {
template<typename T>
struct cast_error_string {
std::string operator()(const py::handle &h, const std::string &name) {
return "Unable to cast Input " + name + " to " + py::type_id<T>() + " from " + (std::string)py::str(py::type::handle_of(h));
}
};

template<>
std::string cast_error_string<Buffer<>>::operator()(const py::handle &h, const std::string &name) {
std::ostringstream o;
o << "Input " << name << " requires an ImageParam or Buffer argument when using generate(), but saw " << (std::string)py::str(py::type::handle_of(h));
return o.str();
}

template<>
std::string cast_error_string<Func>::operator()(const py::handle &h, const std::string &name) {
std::ostringstream o;
o << "Input " << name << " requires a Func argument when using generate(), but saw " << (std::string)py::str(py::type::handle_of(h));
return o.str();
}

template<>
std::string cast_error_string<Expr>::operator()(const py::handle &h, const std::string &name) {
std::ostringstream o;
o << "Input " << name << " requires a Param (or scalar literal) argument when using generate(), but saw " << (std::string)py::str(py::type::handle_of(h));
return o.str();
}

template<typename T>
T cast_to(const py::handle &h, const std::string &name) {
// We want to ensure that the error thrown is one that will be translated
// to `hl.Error` in Python.
try {
return h.cast<T>();
} catch (const std::exception &e) {
throw Halide::Error(cast_error_string<T>()(h, name));
}
}

StubInput to_stub_input(const py::object &o, const std::string &name) {
// Don't use isinstance: we want to get things that
// can be implicitly converted as well (eg ImageParam -> Func)
try {
return StubInput(StubInputBuffer(o.cast<Buffer<>>()));
return StubInput(StubInputBuffer(cast_to<Buffer<>>(o, name)));
} catch (...) {
// Not convertible to Buffer. Fall thru and try next.
}

try {
return StubInput(o.cast<Func>());
return StubInput(cast_to<Func>(o, name));
} catch (...) {
// Not convertible to Func. Fall thru and try next.
}

return StubInput(o.cast<Expr>());
return StubInput(cast_to<Expr>(o, name));
}

std::vector<StubInput> to_stub_inputs(const py::object &value) {
std::vector<StubInput> to_stub_inputs(const py::object &value, const std::string &name) {
if (is_real_sequence(value)) {
std::vector<StubInput> v;
for (const auto &o : py::reinterpret_borrow<py::sequence>(value)) {
v.push_back(to_stub_input(o));
v.push_back(to_stub_input(o, name));
}
return v;
} else {
return {to_stub_input(value)};
return {to_stub_input(value, name)};
}
}

Expand Down Expand Up @@ -158,7 +214,7 @@ py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext
auto it = kw_inputs.find(name);
_halide_user_assert(it != kw_inputs.end()) << "Unknown input '" << name << "' specified via keyword argument.";
_halide_user_assert(it->second.empty()) << "Generator Input named '" << it->first << "' was specified more than once.";
it->second = to_stub_inputs(py::cast<py::object>(value));
it->second = to_stub_inputs(py::cast<py::object>(value), name);
kw_inputs_specified++;
}

Expand All @@ -178,8 +234,8 @@ py::object generate_impl(const GeneratorFactory &factory, const GeneratorContext
<< "Cannot use both positional and keyword arguments for inputs.";
_halide_user_assert(args.size() == names.inputs.size())
<< "Expected exactly " << names.inputs.size() << " positional args for inputs, but saw " << args.size() << ".";
for (auto arg : args) {
inputs.push_back(to_stub_inputs(py::cast<py::object>(arg)));
for (size_t i = 0; i < args.size(); i++) {
inputs.push_back(to_stub_inputs(py::cast<py::object>(args[i]), names.inputs[i]));
}
}

Expand Down

0 comments on commit 47d8103

Please sign in to comment.