Skip to content

Commit

Permalink
Update Model initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Jul 13, 2022
1 parent b524b4e commit dd904b7
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ModelWrappers"
uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6"
authors = ["Patrick Aschermayr <[email protected]>"]
version = "0.2.4"
version = "0.2.5"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
2 changes: 2 additions & 0 deletions src/Models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ include("parameterinfo.jl")
include("modelwrapper.jl")
include("tagged.jl")
include("objective.jl")
include("initial.jl")

#!NOTE: Remove Soss dependency from ModelWrappers because of heavy deps. Can make separate BaytesSoss later on.
#include("_soss.jl")
############################################################################################
Expand Down
56 changes: 56 additions & 0 deletions src/Models/initial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
############################################################################################
"""
$(TYPEDEF)
Abstract method to initialize parameter for individual kernels.
# Fields
$(TYPEDFIELDS)
"""
abstract type AbstractInitialization end

"Use current model.val parameter as initial parameter"
struct NoInitialization <: AbstractInitialization end

"Sample (up to Ntrials) times from prior and check if log target distribution is finite at proposed parameter in unconstrained space."
struct PriorInitialization <: AbstractInitialization
Ntrials::Int64
function PriorInitialization(Ntrials::Integer)
return new(Ntrials)
end
end

"Use custom optimization technique for initialization."
struct OptimInitialization{T} <: AbstractInitialization
method::T
function OptimInitialization(method::T) where {T}
return new{T}(method)
end
end

############################################################################################
function (initialization::NoInitialization)(algorithm, objective::Objective)
#Check if initial parameter satisfy prior constraints
ℓθᵤ = objective(unconstrain_flatten(objective.model, objective.tagged))
@argcheck isfinite(ℓθᵤ) "Log target function at initial value not finite. Change initial parameter or sample from prior via PriorParameter"
return nothing
end

function (initialization::PriorInitialization)(algorithm, objective::Objective)
# Set initial counter
@unpack Ntrials = init
ℓθᵤ = -Inf
counter = 0
# Sample from prior until finite log target is obtained
while !isfinite(ℓθᵤ) && counter <= Ntrials
counter += 1
sample!(objective.model, objective.tagged)
ℓθᵤ = objective(unconstrain_flatten(objective.model, objective.tagged))
end
ArgCheck.@argcheck counter <= NInitial "Could find initial parameter with finite log target density. Adjust intial values, prior, or increase number of intial samples."
return nothing
end

############################################################################################
# Export
export AbstractInitialization, NoInitialization, PriorInitialization, OptimInitialization
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ using ModelWrappers:
Symmetric_from_flatten,
flatten_Simplex,
Simplex_from_flatten!,
Simplex_from_flatten
Simplex_from_flatten,
init

############################################################################################
# Include Files
Expand Down
27 changes: 27 additions & 0 deletions test/test-objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,33 @@ function (objective::Objective{<:ModelWrapper{SossBenchmark}})(θ::NamedTuple)
return lp + ll
end

@testset "Objective - No Initialization" begin
initmethod = NoInitialization()
_objective = deepcopy(obectiveSossBM)
_val = deepcopy(_objective.model.val)

initmethod(nothing, _objective)
@test _val == _objective.model.val
end

@testset "Objective - Prior Initialization" begin
initmethod = PriorInitialization(100)
_objective = deepcopy(obectiveSossBM)
_val = deepcopy(_objective.model.val)

initmethod(nothing, _objective)
@test _val != _objective.model.val
end

@testset "Objective - Prior Initialization, partially Tagged" begin
initmethod = PriorInitialization(100)
_objective = Objective(deepcopy(obectiveSossBM.model), obectiveSossBM.data, Tagged(obectiveSossBM.model, ))
_val = deepcopy(_objective.model.val)

initmethod(nothing, _objective)
@test _val != _objective.model.val
@test _val.μ == _objective.model.val.μ
end

################################################################################
#!NOTE: Remove Soss dependency from ModelWrappers and make separete BaytesSoss, so heavy deps. removed
Expand Down

2 comments on commit dd904b7

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/64157

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.5 -m "<description of version>" dd904b71f8a8bdd3eb337b0da4d791752390f5ce
git push origin v0.2.5

Please sign in to comment.