Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

By default, throw errors from MethodError #5

Merged
merged 2 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions src/EvaluateEquation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module EvaluateEquationModule

import ..EquationModule: Node
import ..EquationModule: Node, string_tree
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: @return_on_false, is_bad_array, vals
import ..EquationUtilsModule: is_constant
Expand Down Expand Up @@ -47,8 +47,12 @@ function eval(current_node)
The bulk of the code is for optimizations and pre-emptive NaN/Inf checks,
which speed up evaluation significantly.

# Returns
# Arguments
- `tree::Node`: The root node of the tree to evaluate.
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
- `operators::OperatorEnum`: The operators used in the tree.

# Returns
- `(output, complete)::Tuple{AbstractVector{T}, Bool}`: the result,
which is a 1D array, as well as if the evaluation completed
successfully (true/false). A `false` complete means an infinity
Expand Down Expand Up @@ -461,19 +465,51 @@ function eval(current_node)
return current_node.operator(eval(current_node.left_child), eval(current_node.right_child))
```


# Arguments
- `tree::Node`: The root node of the tree to evaluate.
- `cX::AbstractArray{T,N}`: The input data to evaluate the tree on.
- `operators::GenericOperatorEnum`: The operators used in the tree.
- `throw_errors::Bool=true`: Whether to throw errors
if they occur during evaluation. Otherwise,
MethodErrors will be caught before they happen and
evaluation will return `nothing`,
rather than throwing an error. This is useful in cases
where you are unsure if a particular tree is valid or not,
and would prefer to work with `nothing` as an output.

# Returns

- `(output, complete)::Tuple{Any, Bool}`: the result,
as well as if the evaluation completed successfully (true/false).
If evaluation failed, `nothing` will be returned for the first argument.
A `false` complete means an operator was called on input types
that it was not defined for.
"""
function eval_tree_array(
tree::Node{T1}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum
) where {T1,T2,N}
tree::Node, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true
)
!throw_errors && return _eval_tree_array(tree, cX, operators, Val(false))
try
return _eval_tree_array(tree, cX, operators, Val(true))
catch e
tree_s = string_tree(tree, operators)
error_msg = "Failed to evaluate tree $(tree_s)."
if isa(e, MethodError)
error_msg *= (
" Note that you can efficiently skip MethodErrors" *
" beforehand by passing `throw_errors=false` to " *
" `eval_tree_array`."
)
end
throw(ErrorException(error_msg))
end
end

function _eval_tree_array(
tree::Node{T1},
cX::AbstractArray{T2,N},
operators::GenericOperatorEnum,
::Val{throw_errors},
) where {T1,T2,N,throw_errors}
if tree.degree == 0
if tree.constant
return (tree.val::T1), true
Expand All @@ -485,27 +521,33 @@ function eval_tree_array(
end
end
elseif tree.degree == 1
return deg1_eval(tree, cX, vals[tree.op], operators)
return deg1_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
else
return deg2_eval(tree, cX, vals[tree.op], operators)
return deg2_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
end
end

function deg1_eval(tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum) where {op_idx}
function deg1_eval(
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
) where {op_idx,throw_errors}
left, complete = eval_tree_array(tree.l, cX, operators)
!complete && return nothing, false
!throw_errors && !complete && return nothing, false
op = operators.unaops[op_idx]
!hasmethod(op, Tuple{typeof(left)}) && return nothing, false
!throw_errors && !hasmethod(op, Tuple{typeof(left)}) && return nothing, false
return op(left), true
end

function deg2_eval(tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum) where {op_idx}
function deg2_eval(
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
) where {op_idx,throw_errors}
left, complete = eval_tree_array(tree.l, cX, operators)
!complete && return nothing, false
!throw_errors && !complete && return nothing, false
right, complete = eval_tree_array(tree.r, cX, operators)
!complete && return nothing, false
!throw_errors && !complete && return nothing, false
op = operators.binops[op_idx]
!hasmethod(op, Tuple{typeof(left),typeof(right)}) && return nothing, false
!throw_errors &&
!hasmethod(op, Tuple{typeof(left),typeof(right)}) &&
return nothing, false
return op(left, right), true
end

Expand Down
6 changes: 4 additions & 2 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ function GenericOperatorEnum(;
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))

function (tree::Node)(X)
out, did_finish = eval_tree_array(tree, X, $operators)
function (tree::Node)(X; throw_errors::Bool=true)
out, did_finish = eval_tree_array(
tree, X, $operators; throw_errors=throw_errors
)
if !did_finish
return nothing
end
Expand Down
42 changes: 42 additions & 0 deletions test/test_error_handling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using DynamicExpressions
using Test

# Test that we generate errors:
baseT = Float64
T = Union{baseT,Vector{baseT},Matrix{baseT}}

scalar_add(x::T, y::T) where {T<:Real} = x + y

operators = GenericOperatorEnum(; binary_operators=[scalar_add], extend_user_operators=true)

x1, x2, x3 = [Node(T; feature=i) for i in 1:3]

tree = Node(1, x1, x2)

# With error handling:
try
eval_tree_array(tree, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=true)
@test false
catch e
@test isa(e, ErrorException)
expected_error_msg = "Failed to evaluate tree"
@test occursin(expected_error_msg, e.msg)
end

# Without error handling:
output, flag = eval_tree_array(
tree, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=false
)
@test output === nothing
@test !flag

# Default is to catch errors:
try
tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@test false
catch e
@test isa(e, ErrorException)
end

# But can be overrided:
output = tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; throw_errors=false)
4 changes: 4 additions & 0 deletions test/unittest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ end
@safetestset "Test tensor operators" begin
include("test_tensor_operators.jl")
end

@safetestset "Test error handling" begin
include("test_error_handling.jl")
end