Skip to content

Commit

Permalink
Update Error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Jul 12, 2022
1 parent ea2e45c commit 5f1754c
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ModelWrappers"
uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6"
authors = ["Patrick Aschermayr <[email protected]>"]
version = "0.2.2"
version = "0.2.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
63 changes: 62 additions & 1 deletion src/Differentiation/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,67 @@ function check_gradients(
)
end

############################################################################################
# Error handling
function checkfinite(θₜ::AbstractVector{T}) where {T<:Real}
return _checkfinite(θₜ)
end
function checkfinite(result::T) where {T<:ℓObjectiveResult}
return isfinite(result.ℓθᵤ) && _checkfinite(result.θᵤ) ? true : false
end
function checkfinite(
result₀::T, result::T, min_Δ::Float64=min_Δ
) where {T<:ℓObjectiveResult}
checkfinite(result) && ((result.ℓθᵤ - result₀.ℓθᵤ) > min_Δ) || return false
return true
end
function checkfinite(
ℓθ₀::R, ℓθ::R, result::T, min_Δ::Float64=min_Δ
) where {R<:Real,T<:ℓObjectiveResult}
checkfinite(result) && ((ℓθ - ℓθ₀) > min_Δ) || return false
return true
end

############################################################################################
"""
$(TYPEDEF)
Stores parameter in constrained space at which logdensity could not be evaluated.
# Fields
$(TYPEDFIELDS)
"""
struct ObjectiveError <: Exception
#!NOTE: Remove Parametric types so error message is shorter
msg::String
θ::NamedTuple
θᵤ::AbstractVector
function ObjectiveError(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real}
msg = "Internal error: leapfrog called from non-finite log density. Proposed parameter in constrained and unconstrained space:"
θ = unflatten_constrain(objective.model, objective.tagged, θᵤ)
new(msg, θ, θᵤ)
end
end

function checkfinite(objective::Objective, θᵤ::AbstractVector{T}) where {T<:Real}
ArgCheck.@argcheck checkfinite(θᵤ) ObjectiveError(objective, θᵤ)
end
function checkfinite(objective::Objective, result::T) where {T<:ℓObjectiveResult}
ArgCheck.@argcheck checkfinite(result) ObjectiveError(objective, result.θᵤ)
end
function checkfinite(
objective::Objective, result₀::T, result::T, min_Δ::Float64=min_Δ
) where {T<:ℓObjectiveResult}
ArgCheck.@argcheck checkfinite(result₀, result, min_Δ) ObjectiveError(objective, result.θᵤ)
end
function checkfinite(
objective::Objective, ℓθ₀::R, ℓθ::R, result::T, min_Δ::Float64=min_Δ
) where {R<:Real,T<:ℓObjectiveResult}
ArgCheck.@argcheck checkfinite(ℓθ₀, ℓθ, result, min_Δ) ObjectiveError(objective, result.θᵤ)
end

############################################################################################
# Export
export check_gradients
export
check_gradients,
checkfinite,
ObjectiveError
17 changes: 0 additions & 17 deletions src/Differentiation/results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,6 @@ function log_density_and_hessian(objective::L, θᵤ::AbstractVector{T}) where {
end
=#

############################################################################################
function checkfinite(result::T) where {T<:ℓObjectiveResult}
return isfinite(result.ℓθᵤ) && _checkfinite(result.θᵤ) ? true : false
end
function checkfinite(
result₀::T, result::T, min_Δ::Float64=min_Δ
) where {T<:ℓObjectiveResult}
checkfinite(result) && ((result.ℓθᵤ - result₀.ℓθᵤ) > min_Δ) || return false
return true
end
function checkfinite(
ℓθ₀::R, ℓθ::R, result::T, min_Δ::Float64=min_Δ
) where {R<:Real,T<:ℓObjectiveResult}
checkfinite(result) && ((ℓθ - ℓθ₀) > min_Δ) || return false
return true
end

############################################################################################
# Export
export
Expand Down
2 changes: 1 addition & 1 deletion src/ModelWrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using BaytesCore:

using DocStringExtensions:
DocStringExtensions, TYPEDEF, TYPEDFIELDS, FIELDS, SIGNATURES, FUNCTIONNAME
using ArgCheck: ArgCheck, @argcheck
using ArgCheck: ArgCheck, @argcheck, Exception
using UnPack: UnPack, @unpack, @pack!
using Random: Random, AbstractRNG, GLOBAL_RNG

Expand Down
16 changes: 16 additions & 0 deletions src/Models/modelwrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,22 @@ function simulate(model::ModelWrapper) end

############################################################################################
# Dispatch Model struct for .Core functions

"""
$(SIGNATURES)
Show current values of Model as NamedTuple
# Examples
```julia
```
"""
function generate_showvalues(model::ModelWrapper)
return function showvalues()
return ((:Parameter, model.val), )
end
end

"""
$(SIGNATURES)
Fill 'model' values with NamedTuple 'θ'.
Expand Down
5 changes: 5 additions & 0 deletions src/Models/tagged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ end

############################################################################################
# Dispatch Tagged struct for .Core functions
function generate_showvalues(model::ModelWrapper, tagged::Tagged)
return function showvalues()
return ((:Parameter, subset(model, tagged)), )
end
end
function fill(model::ModelWrapper, tagged::Tagged, θ::NamedTuple)
return merge(model.val, subset(θ, tagged.parameter))
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Random: Random, AbstractRNG, seed!
using LinearAlgebra
using Distributions, Bijectors, DistributionsAD
using ForwardDiff, ReverseDiff, Zygote
using ArgCheck

############################################################################################
# Import Baytes Packages
Expand Down
15 changes: 15 additions & 0 deletions test/test-differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx))
_ld_inf_fault = log_density(objectiveExample, copy(theta_unconstrained))
_ld_inf_fault.θᵤ[10] = Inf

@test ModelWrappers.checkfinite(theta_unconstrained)
@test !ModelWrappers.checkfinite(theta_unconstrained2)

@test ModelWrappers.checkfinite(_ld_fin)
@test !ModelWrappers.checkfinite(_ld_inf)
@test !ModelWrappers.checkfinite(_ld_inf_fault)
Expand All @@ -28,8 +31,20 @@ objectiveExample = Objective(modelExample, (data1, data2, data3, _idx))
@test ModelWrappers.checkfinite(-Inf, 10.0, _ld_fin) #Infinite to Finite
@test !ModelWrappers.checkfinite(10.0, -Inf, _ld_fin) #Finite to Infinite
@test !ModelWrappers.checkfinite(-Inf, 10.0, _ld_inf) #Finite to Infinite


#!ToDo: Check for correct error message in case logdensity cannot be evaluated
ModelWrappers.checkfinite(objectiveExample, theta_unconstrained)
ModelWrappers.checkfinite(objectiveExample, _ld_fin)
ModelWrappers.checkfinite(objectiveExample, _ld_inf, _ld_fin)
ModelWrappers.checkfinite(objectiveExample, -Inf, 10.0, _ld_fin)

err = ObjectiveError(objectiveExample, theta_unconstrained)
@test isa(err, ArgCheck.Exception)

end


@testset "AutoDiffContainer - Log Objective AutoDiff compatibility - Vectorized Model" begin
## Assign DiffTune
tune_fwd = AutomaticDiffTune(:ForwardDiff, objectiveExample)
Expand Down

2 comments on commit 5f1754c

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/64117

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.3 -m "<description of version>" 5f1754cd5c6c6c6300186623bbc75d31a08da86a
git push origin v0.2.3

Please sign in to comment.