From b524b4e266f4f9bd09f186be69145f824915d006 Mon Sep 17 00:00:00 2001
From: Patrick Aschermayr
Date: Wed, 13 Jul 2022 10:05:32 +0100
Subject: [PATCH] Update flatten
---
Project.toml | 2 +-
src/Core/flatten/construct.jl | 8 ++++++++
src/Differentiation/checks.jl | 13 +++++++------
src/Models/modelwrapper.jl | 3 ++-
src/Models/tagged.jl | 3 ++-
test/test-differentiation.jl | 2 +-
test/test-flatten.jl | 3 +++
test/test-flatten/nested.jl | 18 ++++++++++++++++++
test/test-flatten/types.jl | 32 +++++++++++++++++++++++++++++++-
9 files changed, 73 insertions(+), 11 deletions(-)
diff --git a/Project.toml b/Project.toml
index a3a985f..9daba58 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "ModelWrappers"
uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6"
authors = ["Patrick Aschermayr "]
-version = "0.2.3"
+version = "0.2.4"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
diff --git a/src/Core/flatten/construct.jl b/src/Core/flatten/construct.jl
index 4d188a8..e68f75d 100644
--- a/src/Core/flatten/construct.jl
+++ b/src/Core/flatten/construct.jl
@@ -162,6 +162,10 @@ function flatten end
function flatten(constructor::ReConstructor, x)
return constructor.flatten.strict(x)
end
+function flatten(x)
+ constructor = ReConstructor(x)
+ return flatten(constructor, x), constructor
+end
"""
$(FUNCTIONNAME)(x )
@@ -175,6 +179,10 @@ function flattenAD end
function flattenAD(constructor::ReConstructor, x)
return constructor.flatten.flexible(x)
end
+function flattenAD(x)
+ constructor = ReConstructor(x)
+ return flattenAD(constructor, x), constructor
+end
"""
$(FUNCTIONNAME)(x )
diff --git a/src/Differentiation/checks.jl b/src/Differentiation/checks.jl
index 1a9ddfa..b1f400e 100644
--- a/src/Differentiation/checks.jl
+++ b/src/Differentiation/checks.jl
@@ -96,30 +96,31 @@ $(TYPEDFIELDS)
struct ObjectiveError <: Exception
#!NOTE: Remove Parametric types so error message is shorter
msg::String
+ ℓθᵤ::Real
θ::NamedTuple
θᵤ::AbstractVector
- function ObjectiveError(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real}
+ function ObjectiveError(objective::Objective, ℓθᵤ::S, θᵤ::AbstractVector{T}) where {S<:Real, T<:Real}
msg = "Internal error: leapfrog called from non-finite log density. Proposed parameter in constrained and unconstrained space:"
θ = unflatten_constrain(objective.model, objective.tagged, θᵤ)
- new(msg, θ, θᵤ)
+ new(msg, ℓθᵤ, θ, θᵤ)
end
end
function checkfinite(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real}
- ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, θᵤ)
+ ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, NaN, θᵤ)
end
function checkfinite(objective::Objective, result::T) where {T<:ℓObjectiveResult}
- ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.θᵤ)
+ ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ)
end
function checkfinite(
objective::Objective, result₀::T, result::T, min_Δ::Float64=min_Δ
) where {T<:ℓObjectiveResult}
- ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.θᵤ)
+ ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ)
end
function checkfinite(
objective::Objective, ℓθ₀::R, ℓθ::R, result::T, min_Δ::Float64=min_Δ
) where {R<:Real,T<:ℓObjectiveResult}
- ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.θᵤ)
+ ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ)
end
############################################################################################
diff --git a/src/Models/modelwrapper.jl b/src/Models/modelwrapper.jl
index 496e6db..793c469 100644
--- a/src/Models/modelwrapper.jl
+++ b/src/Models/modelwrapper.jl
@@ -241,7 +241,8 @@ Inplace version of [`sample`](@ref).
"""
function sample!(_rng::Random.AbstractRNG, model::ModelWrapper)
- ArgCheck.@argcheck _checkprior(model.info.constraint) "For inplace sample version, all constraints need to be a Distribution."
+ #!NOTE: Check no longer needed, as Fixed Tags just return current value
+# ArgCheck.@argcheck _checkprior(model.info.constraint) "For inplace sample version, all constraints need to be a Distribution."
model.val = sample(_rng, model)
return nothing
end
diff --git a/src/Models/tagged.jl b/src/Models/tagged.jl
index 8739874..c0f262a 100644
--- a/src/Models/tagged.jl
+++ b/src/Models/tagged.jl
@@ -104,7 +104,8 @@ end
sample(model::ModelWrapper, tagged::Tagged) = sample(Random.GLOBAL_RNG, model, tagged)
function sample!(_rng::Random.AbstractRNG, model::ModelWrapper, tagged::Tagged)
- ArgCheck.@argcheck _checkprior(subset(tagged.info.constraint, tagged.parameter)) "For inplace sample version, all constraints need to be a Distribution."
+ #!NOTE: Check no longer needed, as Fixed Tags just return current value
+# ArgCheck.@argcheck _checkprior(subset(tagged.info.constraint, tagged.parameter)) "For inplace sample version, all constraints need to be a Distribution."
model.val = sample(_rng, model, tagged)
return nothing
end
diff --git a/test/test-differentiation.jl b/test/test-differentiation.jl
index 3c4afa8..3079e2a 100644
--- a/test/test-differentiation.jl
+++ b/test/test-differentiation.jl
@@ -39,7 +39,7 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx))
ModelWrappers.checkfinite(objectiveExample, _ld_inf, _ld_fin)
ModelWrappers.checkfinite(objectiveExample, -Inf, 10.0, _ld_fin)
- err = ObjectiveError(objectiveExample, theta_unconstrained)
+ err = ObjectiveError(objectiveExample, -Inf, theta_unconstrained)
@test isa(err, ArgCheck.Exception)
end
diff --git a/test/test-flatten.jl b/test/test-flatten.jl
index 25d2daa..0fca2c3 100644
--- a/test/test-flatten.jl
+++ b/test/test-flatten.jl
@@ -21,6 +21,8 @@
θ_flat = _flatten(θ)
θ_unflat = _unflatten(θ_flat)
+
+
#!NOTE Do not test if FlattenContinuous and empty Integer Param struct is evaluated
if θ_flat isa Vector{T} where {T<:AbstractFloat}
if unflat isa UnflattenStrict
@@ -87,6 +89,7 @@ end
param = _params[sym]
θ = _get_val(param)
constraint = _get_constraint(param)
+ ## Check all flatten possibilities
for unflat in unflattenmethods
for flat in flattentypes
for floattypes in outputtypes
diff --git a/test/test-flatten/nested.jl b/test/test-flatten/nested.jl
index e2618d6..128a4fc 100644
--- a/test/test-flatten/nested.jl
+++ b/test/test-flatten/nested.jl
@@ -3,6 +3,12 @@
############################################################################################
@testset "Nested - AbstractArray" begin
+ val = [1., [2, 3], [4. 5. ; 6. 7.], 8., [9., 10.]]
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
@@ -95,6 +101,12 @@ end
############################################################################################
@testset "Nested - Tuple" begin
+ val = (1., [2, 3], [4. 5. ; 6. 7.], 8, [9., 10.])
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
@@ -187,6 +199,12 @@ end
############################################################################################
@testset "Nested - NamedTuple" begin
+ val = (a = Float16(1.0), b = [2, 3], c = [4. 5. ; 6. 7.], d = 8, e = [9., 10.], f = (g = (h = 3.)))
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
diff --git a/test/test-flatten/types.jl b/test/test-flatten/types.jl
index df384a1..9dc8d59 100644
--- a/test/test-flatten/types.jl
+++ b/test/test-flatten/types.jl
@@ -3,10 +3,16 @@
############################################################################################
@testset "Types - Float" begin
+ val = Float16(1.0)
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
- flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = Float16(1.0)
+ flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
ReConstructor(val)
reconstruct = ReConstructor(flatdefault, val)
# Flatten
@@ -94,6 +100,12 @@ end
############################################################################################
@testset "Types - Vector Float" begin
+ val = Float16.([1., 2.])
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
@@ -196,6 +208,12 @@ end
############################################################################################
@testset "Types - Array Float" begin
+ val = Float16.([1. 0.3 ; .3 1.0])
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
@@ -298,6 +316,12 @@ end
############################################################################################
@testset "Types - Integer" begin
+ val = Int16(1.0)
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
@@ -328,6 +352,12 @@ end
############################################################################################
@testset "Types - Array Integer" begin
+ val = Int16.([1 2 ; 3 4])
+#Default ReConstructor and flatten
+ val_flat, _reconstruct = flatten(val)
+ @test val_flat == flatten(_reconstruct, val)
+ @test val == unflatten(_reconstruct, val_flat)
+ unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)