Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error handling for self-referential types #144

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
22 changes: 20 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ use-case, consider pre-allocating the `CoDual`s and calling the other method of
function. The `CoDual`s should be primal-tangent pairs (as opposed to primal-fdata pairs).
"""
function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any, N}) where {R, N}
    # Wrap each input in a zero CoDual before running the reverse pass.
    # __create_coduals intercepts the StackOverflowError that self-referential
    # inputs trigger during tangent construction and raises an informative
    # ErrorException instead.
    return __value_and_pullback!!(rule, ȳ, __create_coduals(fx)...)
end

"""
Expand All @@ -71,5 +71,23 @@ end
Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`.
"""
function value_and_gradient!!(rule::R, fx::Vararg{Any, N}) where {R, N}
    # Wrap each input in a zero CoDual before running the reverse pass.
    # __create_coduals intercepts the StackOverflowError that self-referential
    # inputs trigger during tangent construction and raises an informative
    # ErrorException instead.
    return __value_and_gradient!!(rule, __create_coduals(fx)...)
end

"""
    __create_coduals(args)

Map `zero_codual` over the tuple `args`, wrapping each element in a primal-tangent
`CoDual` pair.

Constructing a zero tangent recurses through the fields of each argument, so a
self-referential value recurses forever and overflows the stack. Catch that
`StackOverflowError` and raise an `ErrorException` with an actionable message;
rethrow any other exception unchanged.
"""
function __create_coduals(args)
    try
        return tuple_map(zero_codual, args)
    catch e
        if e isa StackOverflowError
            error(
                "Found a StackOverflowError when trying to wrap inputs. This often " *
                "means that Tapir.jl has encountered a self-referential type. Tapir.jl " *
                "is not presently able to handle self-referential types, so if you are " *
                "indeed using a self-referential type somewhere, you will need to " *
                "refactor to avoid it if you wish to use Tapir.jl."
            )
        else
            # Bare rethrow() preserves the original exception's backtrace.
            rethrow()
        end
    end
end
15 changes: 15 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Minimal mutable container used by the tests below to build a self-referential
# value: after `v = MutableSelfRef(nothing); v.x = v`, the instance contains itself.
# The `Any`-typed field is deliberate — it must be able to hold a `MutableSelfRef`.
mutable struct MutableSelfRef
    x::Any
end

@testset "interface" begin
@testset "$(typeof((f, x...)))" for (ȳ, f, x...) in Any[
(1.0, (x, y) -> x * y + sin(x) * cos(y), 5.0, 4.0),
Expand All @@ -17,4 +21,15 @@
rule = build_rrule(foo, 5.0)
@test_throws ArgumentError value_and_pullback!!(rule, 1.0, foo, CoDual(5.0, 0.0))
end
# Regression test for self-referential inputs: wrapping such a value must raise
# a helpful ErrorException (from __create_coduals) rather than letting the raw
# StackOverflowError escape to the user.
@testset "sensible error occurs when self-reference found" begin
rule = build_rrule(Tapir.PInterp(), Tuple{typeof(identity), MutableSelfRef})
v = MutableSelfRef(nothing)
v.x = v

# Check that zero_tangent for v does indeed cause a stack overflow.
@test_throws StackOverflowError zero_tangent(v)

# Check that we're catching the stack overflow.
@test_throws ErrorException value_and_pullback!!(rule, identity, v)
end
end
Loading