Commit f8107cf: Update Unflatten Type for Simplex

paschermayr committed Apr 2, 2022 · 1 parent 289dc63 · commit f8107cf
Showing 12 changed files with 146 additions and 26 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
name = "ModelWrappers"
uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6"
authors = ["Patrick Aschermayr <[email protected]>"]
version = "0.1.12"
version = "0.1.13"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
src/Core/checks.jl (5 changes: 3 additions & 2 deletions)
@@ -71,6 +71,7 @@ function _checkprior(prior::N) where {N<:NamedTuple}
end

############################################################################################
+#=
"""
$(SIGNATURES)
Check if all keys of 'x' and 'y' match - works with Nested Tuples - and return Bool. Not exported.
@@ -97,7 +98,7 @@ function _checkkeys(x::NamedTuple{Kx,Tx}, y::NamedTuple{Ky,Ty}) where {Kx,Tx,Ky,
## Else return true
return true
end
-
+=#
############################################################################################
"""
$(SIGNATURES)
@@ -114,7 +115,7 @@ end
function _checksampleable(constraint::S) where {S<:Distributions.Distribution}
return true
end
-function _checksampleable(constraint::Vector{T}) where {T<:Real}
+function _checksampleable(constraint::Vector{T}) where {T}
@inbounds @simd for iter in eachindex(constraint)
if !_checksampleable(constraint[iter])
return false
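
Dropping the T<:Real restriction lets the recursive check accept vectors of arbitrary element type: distributions, Fixed() markers, or nested vectors, none of which are <:Real. A minimal sketch of the resulting behavior; the calls mirror the new assertions in test-core.jl (_checksampleable is internal to ModelWrappers):

    using Distributions
    using ModelWrappers: _checksampleable, Fixed
    _checksampleable([Normal(), Normal()])                          # true
    _checksampleable([Fixed(), Normal(), Normal()])                 # false: Fixed() is not sampleable
    _checksampleable([[Normal(), Normal()], [Fixed(), Normal()]])   # false: the check recurses into nested vectors
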
src/Core/constraints/flatten/flatten_bijector.jl (9 changes: 6 additions & 3 deletions)
@@ -53,7 +53,8 @@ function flatten(
=#
#!NOTE: CorrBijector seems to unconstrain to an Upper Diagonal Matrix
idx_upper = tag(x, true, false)
-buffer = ones(T, size(x))
+#!NOTE: Buffer should be of type R, not T, as we want the same type back afterwards
+buffer = ones(R, size(x))
function CorrMatrix_from_vec(x_vec::Union{<:Real,AbstractVector{<:Real}})
return Symmetric_from_flatten!(buffer, x_vec, idx_upper)
end
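
Why the eltype matters: T is the flatten output type requested by FlattenDefault, while R is the eltype of the parameter itself, and the unflatten closure must hand back the parameter's original type. A hypothetical round trip under that reading, with `constraint` standing in for the correlation-matrix constraint this method dispatches on:

    x = [1.0 0.5; 0.5 1.0]                              # Float64 correlation matrix, so R = Float64
    df32 = FlattenDefault(Float32, FlattenContinuous(), UnflattenStrict())
    x_flat, _unflatten = flatten(df32, x, constraint)   # x_flat::Vector{Float32}, since T = Float32
    x_back = _unflatten(x_flat)
    eltype(x_back) == Float64                           # true: the buffer is allocated with R, not T
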
@@ -87,7 +88,8 @@
) where {T<:AbstractFloat,F<:FlattenTypes,R<:Real,C<:Union{Distributions.InverseWishart, Bijectors.PDBijector}}
#!NOTE: PDBijector seems to unconstrain to a Lower Diagonal Matrix
idx_upper = tag(x, false, true) #tag(x, true, true)
-buffer = zeros(T, size(x))
+#!NOTE: Buffer should be of type R, not T, as we want the same type back afterwards
+buffer = zeros(R, size(x))
function Symmetric_from_vec(x_vec::Union{<:Real,AbstractVector{<:Real}})
return Symmetric_from_flatten!(buffer, x_vec, idx_upper)
end
@@ -128,7 +130,8 @@ function flatten(
(3) Consequently, we can flatten to length(x)-1 dimensions, and unflatten back to length(x) by recovering the final element as one minus the sum of the others.
This works for both the constrained and unconstrained case.
=#
-buffer = zeros(T, length(x))
+#!NOTE: Buffer should be of type R, not T, as we want the same type back afterwards
+buffer = zeros(R, length(x))
function unflatten_Simplex(x_vec)
return Simplex_from_flatten!(buffer, x_vec)
end
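
Point (3) above in plain numbers, a sketch of the simplex round trip (not part of the diff):

    θ = [0.2, 0.3, 0.5]                         # a simplex: elements sum to 1
    θ_flat = θ[1:end-1]                         # flatten to length(x)-1 = 2 dimensions
    θ_back = vcat(θ_flat, 1.0 - sum(θ_flat))    # unflatten: the last element is fully determined
    θ_back == θ                                 # true
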
src/Core/constraints/transforms.jl (12 changes: 6 additions & 6 deletions)
@@ -36,8 +36,8 @@ Fill array with elements of vec. Not exported.
return buffer
end
function ChainRulesCore.rrule(
-::typeof(fill_array!), mat::AbstractMatrix{R}, v::Union{R,AbstractVector{R}}
-) where {R<:Real}
+::typeof(fill_array!), mat::AbstractMatrix{R}, v::Union{T,AbstractVector{T}}
+) where {R<:Real,T<:Real}
# forward pass: Fill Matrix with Vector elements
L = fill_array!(mat, v)
# backward pass: Fill Vector with Matrix elements
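
Allowing the eltype T of v to differ from the buffer eltype R admits mixed-precision calls, e.g. an AD backend pushing Float32 values through a Float64 buffer. A hypothetical call under the widened signature, assuming the primal fill_array! method is widened the same way:

    mat = zeros(Float64, 2, 2)      # buffer eltype R = Float64
    v   = Float32[1, 2, 3, 4]       # input eltype T = Float32, T ≠ R
    fill_array!(mat, v)             # per the docstring, fills mat with the elements of v
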
@@ -102,17 +102,17 @@ Inplace version of Simplex_from_flatten. Not exported.
"""
function Simplex_from_flatten!(
-buffer::AbstractVector{R}, x_vec::Union{R,AbstractVector{R}}
-) where {R<:Real}
+buffer::AbstractVector{R}, x_vec::Union{T,AbstractVector{T}}
+) where {R<:Real,T<:Real}
@inbounds for iter in eachindex(x_vec)
buffer[iter] = x_vec[iter]
end
buffer[end] = 1.0 - sum(x_vec)
return buffer
end
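
A short sketch of the in-place version with the newly allowed mixed types (hypothetical values):

    buffer = zeros(Float64, 3)            # R = Float64
    v = Float32[0.2, 0.3]                 # T = Float32 is now accepted
    Simplex_from_flatten!(buffer, v)      # ≈ [0.2, 0.3, 0.5]; buffer[end] = 1 - sum(v)
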
function ChainRulesCore.rrule(
-::typeof(Simplex_from_flatten!), p::AbstractVector{R}, v::Union{R,AbstractVector{R}}
-) where {R<:Real}
+::typeof(Simplex_from_flatten!), p::AbstractVector{R}, v::Union{T,AbstractVector{T}}
+) where {R<:Real,T<:Real}
# forward pass: From k-1 to k dimensions
L = Simplex_from_flatten!(p, v)
# backward pass: From k to k-1 dimensions
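
The pullback body is collapsed in the diff. Since L[i] = v[i] for i < k and L[end] = 1 - sum(v), the adjoint presumably maps a k-dimensional cotangent Δ to k-1 dimensions via Δv[i] = Δ[i] - Δ[end]. A hedged reconstruction under that assumption, not the actual source:

    function Simplex_pullback(Δ)
        # ∂L[i]/∂v[i] = 1 and ∂L[end]/∂v[i] = -1, hence:
        Δv = Δ[1:(end - 1)] .- Δ[end]
        return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), Δv
    end
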
test/TestHelper.jl (4 changes: 4 additions & 0 deletions)
@@ -15,6 +15,10 @@ N = 10^3
df_strict = FlattenDefault(Float64, FlattenContinuous(), UnflattenStrict())
df_AD = FlattenDefault(Float64, FlattenContinuous(), UnflattenAD())

unflattenmethods = [UnflattenStrict(), UnflattenAD()]
flattenmethods = [FlattenAll(), FlattenContinuous()]
flattentypes = [Float64, Float32]

############################################################################################
#Probabilistic Parameters - Some selected distributions
struct ProbModel <: ModelName end
test/runtests.jl (1 change: 0 additions & 1 deletion)
@@ -29,7 +29,6 @@ using ModelWrappers:
_to_inv_bijector,
flatten,
constrain,
-_checkkeys,
_get_val,
_get_constraint,
log_density,
test/test-core.jl (19 changes: 19 additions & 0 deletions)
@@ -6,20 +6,26 @@
@test _checkfinite(.1)
@test !_checkfinite(Inf)
@test _checkfinite( (.1, .2, [.2, 3.], zeros(2,3)))

@test !_checkfinite([.1, -Inf, 3.])
@test !_checkfinite([[2. 3. ; 4. 5], [.1, -Inf, 3.]])

end

@testset "Core - Checkprior" begin
@test _checkprior(Distributions.Normal())
@test !_checkprior(Inf)
@test _checkprior([Distributions.Normal(), Distributions.Normal()])
@test !_checkprior([Inf, Distributions.Normal(), Distributions.Normal()])
@test !_checkprior([[Distributions.Normal(), Distributions.Normal()], [Inf, Distributions.Normal(), Distributions.Normal()]])
end

@testset "Core - Checksampleable" begin
@test _checksampleable(Distributions.Normal())
@test !_checksampleable(Fixed())
@test _checksampleable([Distributions.Normal(), Distributions.Normal()])
@test !_checksampleable([Fixed(), Distributions.Normal(), Distributions.Normal()])
@test !_checksampleable([[Distributions.Normal(), Distributions.Normal()], [Fixed(), Distributions.Normal(), Distributions.Normal()]])
end

@testset "Core - Checkparams" begin
@@ -93,6 +99,13 @@ end
(a = Unconstrained(), c = Fixed(), b = [Normal(), Normal()])
)
@test _names2 == ["a", "b1", "b2"]
_names3 = ModelWrappers.paramnames(
(:a,:c,:b),
FlattenDefault(),
(a = 1., c = 2., b = [3., 4.]),
(a = Unconstrained(), c = Fixed(), b = [Normal(), Normal()])
)
@test all(_names2 .== _names3)
end

@testset "Core - paramcount" begin
@@ -101,4 +114,10 @@
(a = 1., c = 2., b = [3., 4.])
)
@test _vals == (1,1,2)
_vals2 = ModelWrappers.paramcount(
(:a,:c,:b),
FlattenDefault(),
(a = 1., c = 2., b = [3., 4.])
)
@test all(_vals .== _vals2)
end
test/test-differentiation.jl (43 changes: 35 additions & 8 deletions)
@@ -11,14 +11,38 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx))
fwd = DiffObjective(objectiveExample, tune_fwd)
rd = DiffObjective(objectiveExample, tune_rd)
zyg = DiffObjective(objectiveExample, tune_zyg)
-theta_unconstrained = randn(length(modelExample))
+theta_unconstrained = randn(_RNG, length(modelExample))
## Compute logdensity
log_density(objectiveExample)
+theta_unconstrained2 = deepcopy(theta_unconstrained)
+#!NOTE: The 10th parameter enters the likelihood in this example, so it is not compiled away in the Reverse Tape
+theta_unconstrained2[10] = Inf
+_ld = log_density(objectiveExample, theta_unconstrained2)
+@test isinf(_ld.ℓθᵤ)
+_ld = log_density(fwd, theta_unconstrained2)
+@test isinf(_ld.ℓθᵤ)
+_ld = log_density(rd, theta_unconstrained2)
+@test isinf(_ld.ℓθᵤ)
+_ld = log_density(zyg, theta_unconstrained2)
+@test isinf(_ld.ℓθᵤ)

-ld1 = log_density(fwd, theta_unconstrained)
-ld2 = log_density(rd, theta_unconstrained)
-ld3 = log_density(zyg, theta_unconstrained)
+_ld1 = _log_density(objectiveExample, tune_fwd, theta_unconstrained)
+_ld2 = _log_density(objectiveExample, tune_rd, theta_unconstrained)
+_ld3 = _log_density(objectiveExample, tune_zyg, theta_unconstrained)
## Compute Diffresult
+_grad = log_density_and_gradient(fwd, theta_unconstrained2)
+@test isinf(_grad.ℓθᵤ)
+_grad = log_density_and_gradient(rd, theta_unconstrained2)
+#!TODO: Need an example where tape does not compile parameter in infinity
+@test isinf(_grad.ℓθᵤ)
+_grad = log_density_and_gradient(zyg, theta_unconstrained2)
+@test isinf(_grad.ℓθᵤ)
+_grad1 = _log_density_and_gradient(objectiveExample, tune_fwd, theta_unconstrained)
+_grad2 = _log_density_and_gradient(objectiveExample, tune_rd, theta_unconstrained)
+_grad3 = _log_density_and_gradient(objectiveExample, tune_zyg, theta_unconstrained)
+ld1 = log_density(fwd, theta_unconstrained)
+ld2 = log_density(rd, theta_unconstrained)
+ld3 = log_density(zyg, theta_unconstrained)
grad1 = log_density_and_gradient(fwd, theta_unconstrained)
grad2 = log_density_and_gradient(rd, theta_unconstrained)
grad3 = log_density_and_gradient(zyg, theta_unconstrained)
@@ -45,7 +69,7 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx))
ModelWrappers.update(tune_rd, objectiveExample)
ModelWrappers.update(tune_zyg, objectiveExample)
## Config DiffTune
-theta_unconstrained2 = randn(length(objectiveExample))
+theta_unconstrained2 = randn(_RNG, length(objectiveExample))
ModelWrappers._config(ModelWrappers.ADForward(), objectiveExample, theta_unconstrained2)
ModelWrappers._config(ModelWrappers.ADReverse(), objectiveExample, theta_unconstrained2)
ModelWrappers._config(ModelWrappers.ADReverseUntaped(), objectiveExample, theta_unconstrained2)
@@ -65,11 +89,14 @@ objectiveLowerDim = Objective(modelLowerDim, nothing)
fwd = DiffObjective(objectiveLowerDim, autodiff_fd)
rd = DiffObjective(objectiveLowerDim, autodiff_rd)
zyg = DiffObjective(objectiveLowerDim, autodiff_zyg)
-theta_unconstrained = randn(length(objectiveLowerDim))
+theta_unconstrained = randn(_RNG, length(objectiveLowerDim))
## Compute Diffresult
ld1 = log_density(fwd, theta_unconstrained)
ld2 = log_density(rd, theta_unconstrained)
ld3 = log_density(zyg, theta_unconstrained)
+_ld1 = _log_density(objectiveLowerDim, autodiff_fd, theta_unconstrained)
+_ld2 = _log_density(objectiveLowerDim, autodiff_rd, theta_unconstrained)
+_ld3 = _log_density(objectiveLowerDim, autodiff_zyg, theta_unconstrained)
grad1 = log_density_and_gradient(fwd, theta_unconstrained)
grad2 = log_density_and_gradient(rd, theta_unconstrained)
grad3 = log_density_and_gradient(zyg, theta_unconstrained)
@@ -97,7 +124,7 @@ objectiveLowerDim = Objective(modelLowerDim, nothing)
ModelWrappers.update(autodiff_rd, objectiveLowerDim)
ModelWrappers.update(autodiff_zyg, objectiveLowerDim)
## Config DiffTune
-theta_unconstrained2 = randn(length(objectiveLowerDim))
+theta_unconstrained2 = randn(_RNG, length(objectiveLowerDim))
ModelWrappers._config(ModelWrappers.ADForward(), objectiveLowerDim, theta_unconstrained2)
ModelWrappers._config(ModelWrappers.ADReverse(), objectiveLowerDim, theta_unconstrained2)
ModelWrappers._config(ModelWrappers.ADReverseUntaped(), objectiveLowerDim, theta_unconstrained2)
Expand Down Expand Up @@ -144,7 +171,7 @@ end
function fun1(objective::Objective{<:ModelWrapper{M}}, θᵤ::AbstractVector{T}) where {M<:ExampleModel, T<:Real}
return zeros(size(θᵤ))
end
-θᵤ = randn(length(objectiveExample))
+θᵤ = randn(_RNG, length(objectiveExample))
fun1(objectiveExample, θᵤ)
@testset "AnalyticDiffTune - " begin
tune_analytic = AnalyticalDiffTune(fun1)
test/test-flatten.jl (45 changes: 45 additions & 0 deletions)
@@ -8,6 +8,29 @@
param = _params[sym]
θ = _get_val(param)
constraint = _get_constraint(param)
## Check all flatten possibilities
for unflat in unflattenmethods
for flat in flattenmethods
for floattypes in flattentypes
#println(unflat, " ", flat, " ", floattypes)
flatdefault = FlattenDefault(;
output = floattypes,
flattentype = flat,
unflattentype = unflat
)
θ_flat, _unflatten = flatten(flatdefault, θ, constraint)
θ_unflat = _unflatten(θ_flat)
#!NOTE: Do not test the eltype when FlattenContinuous is used and an Integer-only Param struct flattens to an empty vector
if flat isa FlattenAll || θ_flat isa Vector{T} where {T<:AbstractFloat}
@test eltype(θ_flat) == floattypes
end
#!NOTE: Check types if UnflattenStrict
if unflat isa UnflattenStrict
@test typeof(θ_unflat) == typeof(θ)
end
end
end
end
@test _checkparams(param)
@test _checkfinite(θ)
@test _checkprior(constraint)
@@ -60,6 +83,28 @@ end
param = _params[sym]
θ = _get_val(param)
constraint = _get_constraint(param)
for unflat in unflattenmethods
for flat in flattenmethods
for floattypes in flattentypes
# println(unflat, " ", flat, " ", floattypes)
flatdefault = FlattenDefault(;
output = floattypes,
flattentype = flat,
unflattentype = unflat
)
θ_flat, _unflatten = flatten(flatdefault, θ, constraint)
θ_unflat = _unflatten(θ_flat)
#!NOTE: Do not test the eltype when FlattenContinuous is used and an Integer-only Param struct flattens to an empty vector
if flat isa FlattenAll || θ_flat isa Vector{T} where {T<:AbstractFloat}
@test eltype(θ_flat) == floattypes
end
#!NOTE: Check types if UnflattenStrict
if unflat isa UnflattenStrict
@test typeof(θ_unflat) == typeof(θ)
end
end
end
end
@test _checkparams(param)
@test _checkfinite(θ)
bij = _to_bijector(constraint)
test/test-models.jl (16 changes: 12 additions & 4 deletions)
@@ -2,10 +2,6 @@
# Basic Functionality
_modelProb = ModelWrapper(ProbModel(), val_dist)
@testset "Models - basic functionality" begin
-## Model Length accounting discrete parameter
-unconstrain(_modelProb)
-flatten(_modelProb)
-unconstrain_flatten(_modelProb)
## Type Check 1 - Constrain/Unconstrain
theta_unconstrained_vec = randn(length(_modelProb))
theta_unconstrained = unflatten(_modelProb, theta_unconstrained_vec)
@@ -21,6 +17,18 @@ _modelProb = ModelWrapper(ProbModel(), val_dist)
## Check if densities match
@test log_prior(_modelProb) + log_abs_det_jac(_modelProb) ≈
log_prior_with_transform(_modelProb)
+## Utility functions
+unconstrain(_modelProb)
+flatten(_modelProb)
+unconstrain_flatten(_modelProb)
+simulate(_modelProb)
+fill(_modelProb, _modelProb.val)
+fill!(_modelProb, _modelProb.val)
+subset(_modelProb, keys(_modelProb.val))
+unflatten!(_modelProb, flatten(_modelProb))
+unflatten_constrain!(_modelProb, unconstrain_flatten(_modelProb))
+sample(_RNG, _modelProb)
+sample!(_RNG, _modelProb)
end

############################################################################################
test/test-objective.jl (7 changes: 6 additions & 1 deletion)
@@ -150,13 +150,18 @@ end
length(objectiveExample)
ModelWrappers.paramnames(objectiveExample)
theta_unconstrained = randn(length(modelExample))
Objective(objectiveExample.model, objectiveExample.data, objectiveExample.tagged, objectiveExample.temperature)
Objective(objectiveExample.model, objectiveExample.data, objectiveExample.tagged)
Objective(objectiveExample.model, objectiveExample.data, keys(objectiveExample.tagged.parameter)[1:2])
Objective(objectiveExample.model, objectiveExample.data, keys(objectiveExample.tagged.parameter)[1])
Objective(objectiveExample.model, objectiveExample.data)

predict(_RNG, objectiveExample)
generate(_RNG, objectiveExample)
generate(_RNG, objectiveExample, ModelWrappers.UpdateTrue())
generate(_RNG, objectiveExample, ModelWrappers.UpdateFalse())
dynamics(objectiveExample)

@test abs(
(log_prior(modelExample) + log_abs_det_jac(modelExample)) -
log_prior_with_transform(modelExample),
test/test-tagged.jl (9 changes: 9 additions & 0 deletions)
@@ -34,6 +34,14 @@ _params = [sample(_modelProb, _targets[iter]) for iter in eachindex(_syms)]
_θ2, _ = flatten(theta_constrained2, _target.info.constraint)
@test sum(abs.(_θ1 - _θ2)) ≈ 0 atol = _TOL
## Utility functions
log_prior(_target, _model_temp.val)
θ_flat = flatten(_model_temp, _target)
unflatten(_model_temp, _target, θ_flat)
unflatten!(_model_temp, _target, θ_flat)
θ_flat_unconstrained = unconstrain_flatten(_model_temp, _target)
unflatten_constrain!(_model_temp, _target, θ_flat_unconstrained)
log_prior_with_transform(_model_temp, _target)

subset(_model_temp, _target)
ModelWrappers.length(_target)
ModelWrappers.paramnames(_target)
@@ -42,5 +50,6 @@ _params = [sample(_modelProb, _targets[iter]) for iter in eachindex(_syms)]
_model_temp.val
sample(_RNG, _model_temp, _target)
sample!(_RNG, _model_temp, _target)

end
end

2 comments on commit f8107cf

@paschermayr (Owner, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/57816

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.1.13 -m "<description of version>" f8107cf9aa55b6730bd839732788395400532d43
git push origin v0.1.13
