Skip to content

Commit

Permalink
Remove use of custom isequal to compare scenarios (#486)
Browse files Browse the repository at this point in the history
* Remove use of custom `isequal` to compare scenarios

* Fix

* Scenario intact toggle

* No testing equality of f
  • Loading branch information
gdalle authored Sep 23, 2024
1 parent 1369010 commit 0a0a943
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 46 deletions.
3 changes: 2 additions & 1 deletion DifferentiationInterface/test/Down/Flux/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ test_differentiation(
# AutoEnzyme() # TODO: fix
],
DIT.flux_scenarios();
isequal=DIT.flux_isequal,
isapprox=DIT.flux_isapprox,
rtol=1e-2,
atol=1e-6,
scenario_intact=false, # TODO: why?
logging=LOGGING,
)
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Down/Lux/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ Random.seed!(0)
test_differentiation(
AutoZygote(),
DIT.lux_scenarios(Random.Xoshiro(63));
isequal=DIT.lux_isequal,
isapprox=DIT.lux_isapprox,
rtol=1.0f-2,
atol=1.0f-3,
scenario_intact=false, # TODO: why?
logging=LOGGING,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ function gradient_finite_differences(loss, model)
return re(only(gs))
end

function DIT.flux_isequal(a, b)
return all(isequal.(fleaves(a), fleaves(b)))
end

function DIT.flux_isapprox(a, b; atol, rtol)
isapprox_results = fmapstructure_with_path(a, b) do kp, x, y
if :state in kp # ignore RNN and LSTM state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ Relevant discussions:
- https://github.com/LuxDL/Lux.jl/issues/769
=#

function DIT.lux_isequal(a, b)
return check_approx(a, b; atol=0, rtol=0)
end

function DIT.lux_isapprox(a, b; atol, rtol)
return check_approx(a, b; atol, rtol)
end
Expand Down
15 changes: 15 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/scenario.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ function Scenario{op,pl_op}(
return Scenario{op,pl_op,:in}(f!; x, y, tang, contexts, res1, res2)
end

Base.:(==)(scen1::Scenario, scen2::Scenario) = false

function Base.:(==)(
scen1::Scenario{op,pl_op,pl_fun}, scen2::Scenario{op,pl_op,pl_fun}
) where {op,pl_op,pl_fun}
eq_f = scen1.f == scen2.f
eq_x = scen1.x == scen2.x
eq_y = scen1.y == scen2.y
eq_tang = scen1.tang == scen2.tang
eq_contexts = scen1.contexts == scen2.contexts
eq_res1 = scen1.res1 == scen2.res1
eq_res2 = scen1.res2 == scen2.res2
return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2)
end

operator(::Scenario{op}) where {op} = op
operator_place(::Scenario{op,pl_op}) where {op,pl_op} = pl_op
function_place(::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} = pl_fun
Expand Down
6 changes: 3 additions & 3 deletions DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ Filtering:
Options:
- `logging=false`: whether to log progress
- `isequal=isequal`: function used to compare objects exactly, with the standard signature `isequal(x, y)`
- `isapprox=isapprox`: function used to compare objects approximately, with the standard signature `isapprox(x, y; atol, rtol)`
- `atol=0`: absolute precision for correctness testing (when comparing to the reference outputs)
- `rtol=1e-3`: relative precision for correctness testing (when comparing to the reference outputs)
- `scenario_intact=true`: whether to check that the scenario remains unchanged after the operators are applied
"""
function test_differentiation(
backends::Vector{<:AbstractADType},
Expand All @@ -63,10 +63,10 @@ function test_differentiation(
excluded::Vector{Symbol}=Symbol[],
# options
logging::Bool=false,
isequal=isequal,
isapprox=isapprox,
atol::Real=0,
rtol::Real=1e-3,
scenario_intact::Bool=true,
)
scenarios = filter_scenarios(
scenarios; first_order, second_order, input_type, output_type, excluded
Expand Down Expand Up @@ -109,7 +109,7 @@ function test_differentiation(
],
)
correctness && @testset "Correctness" begin
test_correctness(backend, scen; isequal, isapprox, atol, rtol)
test_correctness(backend, scen; isapprox, atol, rtol, scenario_intact)
end
type_stability && @testset "Type stability" begin
@static if VERSION >= v"1.7"
Expand Down
55 changes: 24 additions & 31 deletions DifferentiationInterfaceTest/src/tests/correctness_eval.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
function test_scen_intact(new_scen, scen; isequal)
for n in fieldnames(typeof(scen))
n == :f && continue
@test isequal(getfield(new_scen, n), getfield(scen, n))
end
end

for op in [
:derivative,
:gradient,
Expand Down Expand Up @@ -55,10 +48,10 @@ for op in [
@eval function test_correctness(
ba::AbstractADType,
scen::$S1out;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
xrand = myrandom(x)
Expand All @@ -80,17 +73,17 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S1in;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
xrand = myrandom(x)
Expand Down Expand Up @@ -128,7 +121,7 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

Expand All @@ -137,10 +130,10 @@ for op in [
@eval function test_correctness(
ba::AbstractADType,
scen::$S2out;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
xrand, yrand = myrandom(x), myrandom(y)
Expand Down Expand Up @@ -172,17 +165,17 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S2in;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
xrand, yrand = myrandom(x), myrandom(y)
Expand Down Expand Up @@ -222,18 +215,18 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

elseif op in [:second_derivative, :hessian]
@eval function test_correctness(
ba::AbstractADType,
scen::$S1out;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, res1, res2, contexts) = new_scen = deepcopy(scen)
xrand = myrandom(x)
Expand Down Expand Up @@ -261,17 +254,17 @@ for op in [
@test res2_out2_noval scen.res2
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S1in;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, res1, res2, contexts) = new_scen = deepcopy(scen)
xrand = myrandom(x)
Expand Down Expand Up @@ -313,18 +306,18 @@ for op in [
@test res2_out2_noval scen.res2
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

elseif op in [:pushforward, :pullback]
@eval function test_correctness(
ba::AbstractADType,
scen::$S1out;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
xrand, tangrand = myrandom(x), myrandom(tang)
Expand Down Expand Up @@ -354,17 +347,17 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S1in;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
xrand, tangrand = myrandom(x), myrandom(tang)
Expand Down Expand Up @@ -406,17 +399,17 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S2out;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang)
Expand Down Expand Up @@ -456,17 +449,17 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S2in;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang)
Expand Down Expand Up @@ -510,18 +503,18 @@ for op in [
@test res1_out2_noval scen.res1
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

elseif op in [:hvp]
@eval function test_correctness(
ba::AbstractADType,
scen::$S1out;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen)
xrand, tangrand = myrandom(x), myrandom(tang)
Expand All @@ -539,17 +532,17 @@ for op in [
@test res2_out2_noval scen.res2
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end

@eval function test_correctness(
ba::AbstractADType,
scen::$S1in;
isequal::Function,
isapprox::Function,
atol::Real,
rtol::Real,
scenario_intact::Bool,
)
@compat (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen)
xrand, tangrand = myrandom(x), myrandom(tang)
Expand All @@ -575,7 +568,7 @@ for op in [
@test res2_out2_noval scen.res2
end
end
test_scen_intact(new_scen, scen; isequal)
scenario_intact && @test new_scen == scen
return nothing
end
end
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterfaceTest/test/weird.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,21 @@ Random.seed!(0)
test_differentiation(
AutoZygote(),
DIT.flux_scenarios();
isequal=DIT.flux_isequal,
isapprox=DIT.flux_isapprox,
rtol=1e-2,
atol=1e-6,
scenario_intact=false,
logging=LOGGING,
)

#=
test_differentiation(
AutoZygote(),
DIT.lux_scenarios(Random.Xoshiro(63));
isequal=DIT.lux_isequal,
isapprox=DIT.lux_isapprox,
rtol=1.0f-2,
atol=1.0f-3,
scenario_intact=false,
logging=LOGGING,
)
=#

0 comments on commit 0a0a943

Please sign in to comment.