diff --git a/src/EvaluateEquation.jl b/src/EvaluateEquation.jl index dcf64f49..d7c713a6 100644 --- a/src/EvaluateEquation.jl +++ b/src/EvaluateEquation.jl @@ -1,6 +1,6 @@ module EvaluateEquationModule -import ..EquationModule: Node +import ..EquationModule: Node, string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: @return_on_false, is_bad_array, vals import ..EquationUtilsModule: is_constant @@ -47,8 +47,12 @@ function eval(current_node) The bulk of the code is for optimizations and pre-emptive NaN/Inf checks, which speed up evaluation significantly. -# Returns +# Arguments +- `tree::Node`: The root node of the tree to evaluate. +- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. +- `operators::OperatorEnum`: The operators used in the tree. +# Returns - `(output, complete)::Tuple{AbstractVector{T}, Bool}`: the result, which is a 1D array, as well as if the evaluation completed successfully (true/false). A `false` complete means an infinity @@ -461,10 +465,19 @@ function eval(current_node) return current_node.operator(eval(current_node.left_child), eval(current_node.right_child)) ``` - +# Arguments +- `tree::Node`: The root node of the tree to evaluate. +- `cX::AbstractArray{T,N}`: The input data to evaluate the tree on. +- `operators::GenericOperatorEnum`: The operators used in the tree. +- `throw_errors::Bool=true`: Whether to throw errors + if they occur during evaluation. Otherwise, + MethodErrors will be caught before they happen and + evaluation will return `nothing`, + rather than throwing an error. This is useful in cases + where you are unsure if a particular tree is valid or not, + and would prefer to work with `nothing` as an output. # Returns - - `(output, complete)::Tuple{Any, Bool}`: the result, as well as if the evaluation completed successfully (true/false). If evaluation failed, `nothing` will be returned for the first argument. @@ -472,8 +485,31 @@ function eval(current_node) that it was not defined for. """ function eval_tree_array( - tree::Node{T1}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum -) where {T1,T2,N} + tree::Node, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true +) + !throw_errors && return _eval_tree_array(tree, cX, operators, Val(false)) + try + return _eval_tree_array(tree, cX, operators, Val(true)) + catch e + tree_s = string_tree(tree, operators) + error_msg = "Failed to evaluate tree $(tree_s)." + if isa(e, MethodError) + error_msg *= ( + " Note that you can efficiently skip MethodErrors" * + " beforehand by passing `throw_errors=false` to " * + " `eval_tree_array`." + ) + end + throw(ErrorException(error_msg)) + end +end + +function _eval_tree_array( + tree::Node{T1}, + cX::AbstractArray{T2,N}, + operators::GenericOperatorEnum, + ::Val{throw_errors}, +) where {T1,T2,N,throw_errors} if tree.degree == 0 if tree.constant return (tree.val::T1), true @@ -485,27 +521,33 @@ function eval_tree_array( end end elseif tree.degree == 1 - return deg1_eval(tree, cX, vals[tree.op], operators) + return deg1_eval(tree, cX, vals[tree.op], operators, Val(throw_errors)) else - return deg2_eval(tree, cX, vals[tree.op], operators) + return deg2_eval(tree, cX, vals[tree.op], operators, Val(throw_errors)) end end -function deg1_eval(tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum) where {op_idx} +function deg1_eval( + tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors} +) where {op_idx,throw_errors} left, complete = eval_tree_array(tree.l, cX, operators) - !complete && return nothing, false + !throw_errors && !complete && return nothing, false op = operators.unaops[op_idx] - !hasmethod(op, Tuple{typeof(left)}) && return nothing, false + !throw_errors && !hasmethod(op, Tuple{typeof(left)}) && return nothing, false return op(left), true end -function deg2_eval(tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum) where {op_idx} +function deg2_eval( + tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors} +) where {op_idx,throw_errors} left, complete = eval_tree_array(tree.l, cX, operators) - !complete && return nothing, false + !throw_errors && !complete && return nothing, false right, complete = eval_tree_array(tree.r, cX, operators) - !complete && return nothing, false + !throw_errors && !complete && return nothing, false op = operators.binops[op_idx] - !hasmethod(op, Tuple{typeof(left),typeof(right)}) && return nothing, false + !throw_errors && + !hasmethod(op, Tuple{typeof(left),typeof(right)}) && + return nothing, false return op(left, right), true end diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 11bf733b..650aa58a 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -283,8 +283,10 @@ function GenericOperatorEnum(; Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators)) Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators)) - function (tree::Node)(X) - out, did_finish = eval_tree_array(tree, X, $operators) + function (tree::Node)(X; throw_errors::Bool=true) + out, did_finish = eval_tree_array( + tree, X, $operators; throw_errors=throw_errors + ) if !did_finish return nothing end diff --git a/test/test_error_handling.jl b/test/test_error_handling.jl new file mode 100644 index 00000000..e04498d2 --- /dev/null +++ b/test/test_error_handling.jl @@ -0,0 +1,42 @@ +using DynamicExpressions +using Test + +# Test that we generate errors: +baseT = Float64 +T = Union{baseT,Vector{baseT},Matrix{baseT}} + +scalar_add(x::T, y::T) where {T<:Real} = x + y + +operators = GenericOperatorEnum(; binary_operators=[scalar_add], extend_user_operators=true) + +x1, x2, x3 = [Node(T; feature=i) for i in 1:3] + +tree = Node(1, x1, x2) + +# With error handling: +try + eval_tree_array(tree, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=true) + @test false +catch e + @test isa(e, ErrorException) + expected_error_msg = "Failed to evaluate tree" + @test occursin(expected_error_msg, e.msg) +end + +# Without error handling: +output, flag = eval_tree_array( + tree, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=false +) +@test output === nothing +@test !flag + +# Default is to catch errors: +try + tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + @test false +catch e + @test isa(e, ErrorException) +end + +# But can be overrided: +output = tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; throw_errors=false) diff --git a/test/unittest.jl b/test/unittest.jl index 25d89399..b5b2380e 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -51,3 +51,7 @@ end @safetestset "Test tensor operators" begin include("test_tensor_operators.jl") end + +@safetestset "Test error handling" begin + include("test_error_handling.jl") +end