Skip to content

Commit

Permalink
Update flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Jul 13, 2022
1 parent 038e83c commit b524b4e
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ModelWrappers"
uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6"
authors = ["Patrick Aschermayr <[email protected]>"]
version = "0.2.3"
version = "0.2.4"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
8 changes: 8 additions & 0 deletions src/Core/flatten/construct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ function flatten end
function flatten(constructor::ReConstructor, x)
return constructor.flatten.strict(x)
end
function flatten(x)
constructor = ReConstructor(x)
return flatten(constructor, x), constructor
end

"""
$(FUNCTIONNAME)(x )
Expand All @@ -175,6 +179,10 @@ function flattenAD end
function flattenAD(constructor::ReConstructor, x)
return constructor.flatten.flexible(x)
end
function flattenAD(x)
constructor = ReConstructor(x)
return flattenAD(constructor, x), constructor
end

"""
$(FUNCTIONNAME)(x )
Expand Down
13 changes: 7 additions & 6 deletions src/Differentiation/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,30 +96,31 @@ $(TYPEDFIELDS)
struct ObjectiveError <: Exception
#!NOTE: Remove Parametric types so error message is shorter
msg::String
ℓθᵤ::Real
θ::NamedTuple
θᵤ::AbstractVector
function ObjectiveError(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real}
function ObjectiveError(objective::Objective, ℓθᵤ::S, θᵤ::AbstractVector{T}) where {S<:Real, T<:Real}
msg = "Internal error: leapfrog called from non-finite log density. Proposed parameter in constrained and unconstrained space:"
θ = unflatten_constrain(objective.model, objective.tagged, θᵤ)
new(msg, θ, θᵤ)
new(msg, ℓθᵤ, θ, θᵤ)
end
end

function checkfinite(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real}
ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, θᵤ)
ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, NaN, θᵤ)
end
function checkfinite(objective::Objective, result::T) where {T<:ℓObjectiveResult}
ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.θᵤ)
ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ)
end
function checkfinite(
objective::Objective, result₀::T, result::T, min_Δ::Float64=min_Δ
) where {T<:ℓObjectiveResult}
ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.θᵤ)
ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ)
end
function checkfinite(
objective::Objective, ℓθ₀::R, ℓθ::R, result::T, min_Δ::Float64=min_Δ
) where {R<:Real,T<:ℓObjectiveResult}
ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.θᵤ)
ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.ℓθᵤ, result.θᵤ)
end

############################################################################################
Expand Down
3 changes: 2 additions & 1 deletion src/Models/modelwrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ Inplace version of [`sample`](@ref).
"""
function sample!(_rng::Random.AbstractRNG, model::ModelWrapper)
ArgCheck.@argcheck _checkprior(model.info.constraint) "For inplace sample version, all constraints need to be a Distribution."
#!NOTE: Check no longer needed, as Fixed Tags just return current value
# ArgCheck.@argcheck _checkprior(model.info.constraint) "For inplace sample version, all constraints need to be a Distribution."
model.val = sample(_rng, model)
return nothing
end
Expand Down
3 changes: 2 additions & 1 deletion src/Models/tagged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ end
sample(model::ModelWrapper, tagged::Tagged) = sample(Random.GLOBAL_RNG, model, tagged)

function sample!(_rng::Random.AbstractRNG, model::ModelWrapper, tagged::Tagged)
ArgCheck.@argcheck _checkprior(subset(tagged.info.constraint, tagged.parameter)) "For inplace sample version, all constraints need to be a Distribution."
#!NOTE: Check no longer needed, as Fixed Tags just return current value
# ArgCheck.@argcheck _checkprior(subset(tagged.info.constraint, tagged.parameter)) "For inplace sample version, all constraints need to be a Distribution."
model.val = sample(_rng, model, tagged)
return nothing
end
Expand Down
2 changes: 1 addition & 1 deletion test/test-differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx))
ModelWrappers.checkfinite(objectiveExample, _ld_inf, _ld_fin)
ModelWrappers.checkfinite(objectiveExample, -Inf, 10.0, _ld_fin)

err = ObjectiveError(objectiveExample, theta_unconstrained)
err = ObjectiveError(objectiveExample, -Inf, theta_unconstrained)
@test isa(err, ArgCheck.Exception)

end
Expand Down
3 changes: 3 additions & 0 deletions test/test-flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
θ_flat = _flatten(θ)
θ_unflat = _unflatten(θ_flat)



#!NOTE Do not test if FlattenContinuous and empty Integer Param struct is evaluated
if θ_flat isa Vector{T} where {T<:AbstractFloat}
if unflat isa UnflattenStrict
Expand Down Expand Up @@ -87,6 +89,7 @@ end
param = _params[sym]
θ = _get_val(param)
constraint = _get_constraint(param)
## Check all flatten possibilities
for unflat in unflattenmethods
for flat in flattentypes
for floattypes in outputtypes
Expand Down
18 changes: 18 additions & 0 deletions test/test-flatten/nested.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

############################################################################################
@testset "Nested - AbstractArray" begin
val = [1., [2, 3], [4. 5. ; 6. 7.], 8., [9., 10.]]
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down Expand Up @@ -95,6 +101,12 @@ end

############################################################################################
@testset "Nested - Tuple" begin
val = (1., [2, 3], [4. 5. ; 6. 7.], 8, [9., 10.])
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down Expand Up @@ -187,6 +199,12 @@ end

############################################################################################
@testset "Nested - NamedTuple" begin
val = (a = Float16(1.0), b = [2, 3], c = [4. 5. ; 6. 7.], d = 8, e = [9., 10.], f = (g = (h = 3.)))
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down
32 changes: 31 additions & 1 deletion test/test-flatten/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@

############################################################################################
@testset "Types - Float" begin
val = Float16(1.0)
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
val = Float16(1.0)
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
ReConstructor(val)
reconstruct = ReConstructor(flatdefault, val)
# Flatten
Expand Down Expand Up @@ -94,6 +100,12 @@ end

############################################################################################
@testset "Types - Vector Float" begin
val = Float16.([1., 2.])
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down Expand Up @@ -196,6 +208,12 @@ end

############################################################################################
@testset "Types - Array Float" begin
val = Float16.([1. 0.3 ; .3 1.0])
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down Expand Up @@ -298,6 +316,12 @@ end

############################################################################################
@testset "Types - Integer" begin
val = Int16(1.0)
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down Expand Up @@ -328,6 +352,12 @@ end

############################################################################################
@testset "Types - Array Integer" begin
val = Int16.([1 2 ; 3 4])
#Default ReConstructor and flatten
val_flat, _reconstruct = flatten(val)
@test val_flat == flatten(_reconstruct, val)
@test val == unflatten(_reconstruct, val_flat)
unflattenAD(_reconstruct, val_flat)
for output in outputtypes
for flattentype in flattentypes
flatdefault = FlattenDefault(; output = output, flattentype = flattentype)
Expand Down

2 comments on commit b524b4e

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/64145

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.4 -m "<description of version>" b524b4e266f4f9bd09f186be69145f824915d006
git push origin v0.2.4

Please sign in to comment.