Skip to content

Commit

Permalink
Fix turbo macro for isfinite checks
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Oct 27, 2022
1 parent 224a567 commit b3a5fdc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 18 deletions.
19 changes: 11 additions & 8 deletions src/EvaluateEquation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ function deg1_l2_ll0_lr0_eval(
cumulator = Array{T,1}(undef, n)
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
x_l = op_l(val_ll, cX[feature_lr, j])::T
x = isfinite(x_l) ? op(x_l)::T : T(Inf) # These will get discovered by _eval_tree_array at end.
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
cumulator[j] = x
end
return (cumulator, true)
Expand Down Expand Up @@ -273,6 +273,7 @@ function deg1_l1_ll0_eval(
end
end

# op(x, y) for x and y variable/constant
function deg2_l0_r0_eval(
tree::Node{T},
cX::AbstractMatrix{T},
Expand Down Expand Up @@ -320,6 +321,7 @@ function deg2_l0_r0_eval(
return (cumulator, true)
end

# op(x, y) for x variable/constant, y arbitrary
function deg2_l0_eval(
tree::Node{T},
cX::AbstractMatrix{T},
Expand Down Expand Up @@ -349,6 +351,7 @@ function deg2_l0_eval(
return (cumulator, true)
end

# op(x, y) for x arbitrary, y variable/constant
function deg2_r0_eval(
tree::Node{T},
cX::AbstractMatrix{T},
Expand Down Expand Up @@ -520,9 +523,9 @@ function eval(current_node)
function eval_tree_array(
tree::Node, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true
)
!throw_errors && return _eval_tree_array(tree, cX, operators, Val(false))
!throw_errors && return _eval_tree_array_generic(tree, cX, operators, Val(false))
try
return _eval_tree_array(tree, cX, operators, Val(true))
return _eval_tree_array_generic(tree, cX, operators, Val(true))
catch e
tree_s = string_tree(tree, operators)
error_msg = "Failed to evaluate tree $(tree_s)."
Expand All @@ -537,7 +540,7 @@ function eval_tree_array(
end
end

function _eval_tree_array(
function _eval_tree_array_generic(
tree::Node{T1},
cX::AbstractArray{T2,N},
operators::GenericOperatorEnum,
Expand All @@ -554,13 +557,13 @@ function _eval_tree_array(
end
end
elseif tree.degree == 1
return deg1_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
return deg1_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
else
return deg2_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
return deg2_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
end
end

function deg1_eval(
function deg1_eval_generic(
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
) where {op_idx,throw_errors}
left, complete = eval_tree_array(tree.l, cX, operators)
Expand All @@ -570,7 +573,7 @@ function deg1_eval(
return op(left), true
end

function deg2_eval(
function deg2_eval_generic(
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
) where {op_idx,throw_errors}
left, complete = eval_tree_array(tree.l, cX, operators)
Expand Down
17 changes: 15 additions & 2 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,22 @@ function _remove_type_assertions(ex::Expr)
return Expr(ex.head, map(_remove_type_assertions, ex.args)...)
end
end
_remove_type_assertions(ex) = ex

function _remove_type_assertions(ex)
return ex
"""Replace instances of (isfinite(x) ? op(x) : T(Inf)) with op(x)"""
function _remove_isfinite(ex::Expr)
if (
ex.head == :if &&
length(ex.args) == 3 &&
ex.args[1].head == :call &&
ex.args[1].args[1] == :isfinite
)
return _remove_isfinite(ex.args[2])
else
return Expr(ex.head, map(_remove_isfinite, ex.args)...)
end
end
_remove_isfinite(ex) = ex

"""
@maybe_turbo use_turbo expression
Expand All @@ -26,6 +38,7 @@ This will also remove all type assertions from the expression.
macro maybe_turbo(turboflag, ex)
# Thanks @jlapeyre https://discourse.julialang.org/t/optional-macro-invocation/18588
clean_ex = _remove_type_assertions(ex)
clean_ex = _remove_isfinite(clean_ex)
turbo_ex = Expr(:macrocall, Symbol("@turbo"), LineNumberNode(@__LINE__), clean_ex)
simple_ex = Expr(
:macrocall,
Expand Down
29 changes: 21 additions & 8 deletions test/test_evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ using Test
include("test_params.jl")

# Test simple evaluations:
operators = OperatorEnum(;
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
)

functions = [
# deg2_l0_r0_eval
(x1, x2, x3) -> x1 * x2,
Expand Down Expand Up @@ -39,7 +35,9 @@ functions = [
(x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0,
]

for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
for turbo in [false, true],
T in [Float16, Float32, Float64],
(i_func, fnc) in enumerate(functions)

# Float16 not implemented:
turbo && T == Float16 && continue
Expand All @@ -53,7 +51,11 @@ for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
nodefnc = fnc
end

local tree, X
local tree, operators, X
operators = OperatorEnum(;
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
)
@extend_operators operators
tree = nodefnc(Node("x1"), Node("x2"), Node("x3"))
tree = convert(Node{T}, tree)

Expand All @@ -65,14 +67,25 @@ for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
true_y = realfnc.(X[1, :], X[2, :], X[3, :])

zero_tolerance = (T == Float16 ? 1e-4 : 1e-6)
@test all(abs.(test_y .- true_y) / N .< zero_tolerance)
try
@test all(abs.(test_y .- true_y) / N .< zero_tolerance)
catch
println("Test for type $T and turbo=$turbo and function $i_func $tree failed.")
mse = sum((x,) -> x^2, test_y .- true_y) / N
mean = sum(test_y) / N
stdev = sqrt(sum((x,) -> x^2, true_y .- mean) / N)
println("Relative error: $(mse / stdev)")
end
end

for turbo in [false, true], T in [Float16, Float32, Float64]
turbo && T == Float16 && continue
# Test specific branches of evaluation code:
# op(op(<constant>))
local tree
local tree, operators
operators = OperatorEnum(;
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
)
tree = Node(1, Node(1, Node(; val=3.0f0)))
@test repr(tree) == "cos(cos(3.0))"
tree = convert(Node{T}, tree)
Expand Down

0 comments on commit b3a5fdc

Please sign in to comment.