From b524b4e266f4f9bd09f186be69145f824915d006 Mon Sep 17 00:00:00 2001 From: Patrick Aschermayr Date: Wed, 13 Jul 2022 10:05:32 +0100 Subject: [PATCH] Update flatten --- Project.toml | 2 +- src/Core/flatten/construct.jl | 8 ++++++++ src/Differentiation/checks.jl | 13 +++++++------ src/Models/modelwrapper.jl | 3 ++- src/Models/tagged.jl | 3 ++- test/test-differentiation.jl | 2 +- test/test-flatten.jl | 3 +++ test/test-flatten/nested.jl | 18 ++++++++++++++++++ test/test-flatten/types.jl | 32 +++++++++++++++++++++++++++++++- 9 files changed, 73 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index a3a985f..9daba58 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelWrappers" uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6" authors = ["Patrick Aschermayr "] -version = "0.2.3" +version = "0.2.4" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/Core/flatten/construct.jl b/src/Core/flatten/construct.jl index 4d188a8..e68f75d 100644 --- a/src/Core/flatten/construct.jl +++ b/src/Core/flatten/construct.jl @@ -162,6 +162,10 @@ function flatten end function flatten(constructor::ReConstructor, x) return constructor.flatten.strict(x) end +function flatten(x) + constructor = ReConstructor(x) + return flatten(constructor, x), constructor +end """ $(FUNCTIONNAME)(x ) @@ -175,6 +179,10 @@ function flattenAD end function flattenAD(constructor::ReConstructor, x) return constructor.flatten.flexible(x) end +function flattenAD(x) + constructor = ReConstructor(x) + return flattenAD(constructor, x), constructor +end """ $(FUNCTIONNAME)(x ) diff --git a/src/Differentiation/checks.jl b/src/Differentiation/checks.jl index 1a9ddfa..b1f400e 100644 --- a/src/Differentiation/checks.jl +++ b/src/Differentiation/checks.jl @@ -96,30 +96,31 @@ $(TYPEDFIELDS) struct ObjectiveError <: Exception #!NOTE: Remove Parametric types so error message is shorter msg::String + ℓθᵤ::Real θ::NamedTuple θᵤ::AbstractVector - function ObjectiveError(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real} + function ObjectiveError(objective::Objective, ℓθᵤ::S, θᵤ::AbstractVector{T}) where {S<:Real, T<:Real} msg = "Internal error: leapfrog called from non-finite log density. Proposed parameter in constrained and unconstrained space:" θ = unflatten_constrain(objective.model, objective.tagged, θᵤ) - new(msg, θ, θᵤ) + new(msg, ℓθᵤ, θ, θᵤ) end end function checkfinite(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real} - ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, θᵤ) + ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, NaN, θᵤ) end function checkfinite(objective::Objective, result::T) where {T<:ℓObjectiveResult} - ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.θᵤ) + ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ) end function checkfinite( objective::Objective, result₀::T, result::T, min_Δ::Float64=min_Δ ) where {T<:ℓObjectiveResult} - ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.θᵤ) + ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ) end function checkfinite( objective::Objective, ℓθ₀::R, ℓθ::R, result::T, min_Δ::Float64=min_Δ ) where {R<:Real,T<:ℓObjectiveResult} - ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.θᵤ) + ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ) end ############################################################################################ diff --git a/src/Models/modelwrapper.jl b/src/Models/modelwrapper.jl index 496e6db..793c469 100644 --- a/src/Models/modelwrapper.jl +++ b/src/Models/modelwrapper.jl @@ -241,7 +241,8 @@ Inplace version of [`sample`](@ref). """ function sample!(_rng::Random.AbstractRNG, model::ModelWrapper) - ArgCheck.@argcheck _checkprior(model.info.constraint) "For inplace sample version, all constraints need to be a Distribution." + #!NOTE: Check no longer needed, as Fixed Tags just return current value +# ArgCheck.@argcheck _checkprior(model.info.constraint) "For inplace sample version, all constraints need to be a Distribution." model.val = sample(_rng, model) return nothing end diff --git a/src/Models/tagged.jl b/src/Models/tagged.jl index 8739874..c0f262a 100644 --- a/src/Models/tagged.jl +++ b/src/Models/tagged.jl @@ -104,7 +104,8 @@ end sample(model::ModelWrapper, tagged::Tagged) = sample(Random.GLOBAL_RNG, model, tagged) function sample!(_rng::Random.AbstractRNG, model::ModelWrapper, tagged::Tagged) - ArgCheck.@argcheck _checkprior(subset(tagged.info.constraint, tagged.parameter)) "For inplace sample version, all constraints need to be a Distribution." + #!NOTE: Check no longer needed, as Fixed Tags just return current value +# ArgCheck.@argcheck _checkprior(subset(tagged.info.constraint, tagged.parameter)) "For inplace sample version, all constraints need to be a Distribution." model.val = sample(_rng, model, tagged) return nothing end diff --git a/test/test-differentiation.jl b/test/test-differentiation.jl index 3c4afa8..3079e2a 100644 --- a/test/test-differentiation.jl +++ b/test/test-differentiation.jl @@ -39,7 +39,7 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx)) ModelWrappers.checkfinite(objectiveExample, _ld_inf, _ld_fin) ModelWrappers.checkfinite(objectiveExample, -Inf, 10.0, _ld_fin) - err = ObjectiveError(objectiveExample, theta_unconstrained) + err = ObjectiveError(objectiveExample, -Inf, theta_unconstrained) @test isa(err, ArgCheck.Exception) end diff --git a/test/test-flatten.jl b/test/test-flatten.jl index 25d2daa..0fca2c3 100644 --- a/test/test-flatten.jl +++ b/test/test-flatten.jl @@ -21,6 +21,8 @@ θ_flat = _flatten(θ) θ_unflat = _unflatten(θ_flat) + + #!NOTE Do not test if FlattenContinuous and empty Integer Param struct is evaluated if θ_flat isa Vector{T} where {T<:AbstractFloat} if unflat isa UnflattenStrict @@ -87,6 +89,7 @@ end param = _params[sym] θ = _get_val(param) constraint = _get_constraint(param) + ## Check all flatten possibilities for unflat in unflattenmethods for flat in flattentypes for floattypes in outputtypes diff --git a/test/test-flatten/nested.jl b/test/test-flatten/nested.jl index e2618d6..128a4fc 100644 --- a/test/test-flatten/nested.jl +++ b/test/test-flatten/nested.jl @@ -3,6 +3,12 @@ ############################################################################################ @testset "Nested - AbstractArray" begin + val = [1., [2, 3], [4. 5. ; 6. 7.], 8., [9., 10.]] +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) @@ -95,6 +101,12 @@ end ############################################################################################ @testset "Nested - Tuple" begin + val = (1., [2, 3], [4. 5. ; 6. 7.], 8, [9., 10.]) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) @@ -187,6 +199,12 @@ end ############################################################################################ @testset "Nested - NamedTuple" begin + val = (a = Float16(1.0), b = [2, 3], c = [4. 5. ; 6. 7.], d = 8, e = [9., 10.], f = (g = (h = 3.))) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) diff --git a/test/test-flatten/types.jl b/test/test-flatten/types.jl index df384a1..9dc8d59 100644 --- a/test/test-flatten/types.jl +++ b/test/test-flatten/types.jl @@ -3,10 +3,16 @@ ############################################################################################ @testset "Types - Float" begin + val = Float16(1.0) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes - flatdefault = FlattenDefault(; output = output, flattentype = flattentype) val = Float16(1.0) + flatdefault = FlattenDefault(; output = output, flattentype = flattentype) ReConstructor(val) reconstruct = ReConstructor(flatdefault, val) # Flatten @@ -94,6 +100,12 @@ end ############################################################################################ @testset "Types - Vector Float" begin + val = Float16.([1., 2.]) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) @@ -196,6 +208,12 @@ end ############################################################################################ @testset "Types - Array Float" begin + val = Float16.([1. 0.3 ; .3 1.0]) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) @@ -298,6 +316,12 @@ end ############################################################################################ @testset "Types - Integer" begin + val = Int16(1.0) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) @@ -328,6 +352,12 @@ end ############################################################################################ @testset "Types - Array Integer" begin + val = Int16.([1 2 ; 3 4]) +#Default ReConstructor and flatten + val_flat, _reconstruct = flatten(val) + @test val_flat == flatten(_reconstruct, val) + @test val == unflatten(_reconstruct, val_flat) + unflattenAD(_reconstruct, val_flat) for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype)