Skip to content

Commit

Permalink
add wrapper for Optimization.jl, allow to pass kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
Cornelius-G committed Oct 6, 2023
1 parent d9957d9 commit fb70273
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 5 deletions.
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"
Expand All @@ -80,10 +82,12 @@ BATFoldsExt = ["Folds", "Transducers"]
BATHDF5Ext = "HDF5"
BATNestedSamplersExt = "NestedSamplers"
BATOptimExt = "Optim"
BATOptimizationExt = ["Optimization", "ADTypes"]
BATPlotsExt = "Plots"
BATUltraNestExt = "UltraNest"

[compat]
ADTypes = "0.1, 0.2"
Accessors = "0.1"
AdvancedHMC = "0.5"
AffineMaps = "0.2.3"
Expand Down Expand Up @@ -124,6 +128,7 @@ Measurements = "2"
NamedArrays = "0.9, 0.10"
NestedSamplers = "0.8"
Optim = "0.19,0.20, 0.21, 0.22, 1"
Optimization = "3"
PDMats = "0.9, 0.10, 0.11"
ParallelProcessingTools = "0.4"
Parameters = "0.12"
Expand All @@ -147,12 +152,14 @@ ZygoteRules = "0.2"
julia = "1.6"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"
22 changes: 19 additions & 3 deletions ext/BATOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:DEFAULT_OPTALG}) = Optim.
BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:NELDERMEAD_ALG}) = Optim.NelderMead()
BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:LBFGS_ALG}) = Optim.LBFGS()

BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:DEFAULT_OPTS}) = Optim.Options(store_trace = true, extended_trace=true)

struct NLSolversFG!{F,AD} <: Function
f::F
ad::AD
Expand Down Expand Up @@ -61,6 +59,21 @@ function (fg!::NLSolversFG!)(::Nothing, grad_f::AbstractVector{<:Real}, x::Abstr
end


function convert_options(algorithm::OptimAlg)
if algorithm.abstol != NaN
@warn "The option 'abstol' is not used for this algorithm."
end

kwargs = algorithm.kwargs

algopts = (; iterations = algorithm.maxiters, time_limit = algorithm.maxtime, f_tol = algorithm.reltol,)
algopts = (; algopts..., kwargs...)
algopts = (; algopts..., store_trace = true, extended_trace=true)

return Optim.Options(; algopts...)
end


function BAT.bat_findmode_impl(target::AnyMeasureOrDensity, algorithm::OptimAlg, context::BATContext)
transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context)
inv_trafo = inverse(trafo)
Expand All @@ -70,7 +83,10 @@ function BAT.bat_findmode_impl(target::AnyMeasureOrDensity, algorithm::OptimAlg,

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
optim_result = _optim_minimize(f, x_init, algorithm.optalg, algorithm.options, context)

opts = convert_options(algorithm)

optim_result = _optim_minimize(f, x_init, algorithm.optalg, opts, context)
r_optim = Optim.MaximizationWrapper(optim_result)
transformed_mode = Optim.minimizer(r_optim.res)
result_mode = inv_trafo(transformed_mode)
Expand Down
72 changes: 72 additions & 0 deletions ext/BATOptimizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).

module BATOptimizationExt

@static if isdefined(Base, :get_extension)
import Optimization
else
import ..Optimization
end

using BAT
BAT.pkgext(::Val{:Optimization}) = BAT.PackageExtension{:Optimization}()

Check warning on line 12 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L12

Added line #L12 was not covered by tests


using Random
using DensityInterface, ChangesOfVariables, InverseFunctions, FunctionChains
using HeterogeneousComputing, AutoDiffOperators
using StructArrays, ArraysOfArrays, ADTypes

using BAT: AnyMeasureOrDensity, AbstractMeasureOrDensity

using BAT: get_context, get_adselector, _NoADSelected
using BAT: bat_initval, transform_and_unshape, apply_trafo_to_init
using BAT: negative


AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg

Check warning on line 28 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L27-L28

Added lines #L27 - L28 were not covered by tests

BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()

Check warning on line 30 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L30

Added line #L30 was not covered by tests


function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
adm = convert_ad(ADTypes.AbstractADType, adsel)
optimization_function = Optimization.OptimizationFunction(f, adm)
return optimization_function

Check warning on line 36 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L33-L36

Added lines #L33 - L36 were not covered by tests
end

function build_optimizationfunction(f, adsel::BAT._NoADSelected)
optimization_function = Optimization.OptimizationFunction(f)
return optimization_function

Check warning on line 41 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L39-L41

Added lines #L39 - L41 were not covered by tests
end


function BAT.bat_findmode_impl(target::AnyMeasureOrDensity, algorithm::OptimizationAlg, context::BATContext)
transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context)
inv_trafo = inverse(trafo)

Check warning on line 47 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L45-L47

Added lines #L45 - L47 were not covered by tests

initalg = apply_trafo_to_init(trafo, algorithm.init)
x_init = collect(bat_initval(transformed_density, initalg, context).result)

Check warning on line 50 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L49-L50

Added lines #L49 - L50 were not covered by tests

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
f2 = (x, p) -> f(x)

Check warning on line 54 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L53-L54

Added lines #L53 - L54 were not covered by tests

adsel = get_adselector(context)

Check warning on line 56 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L56

Added line #L56 was not covered by tests

optimization_function = build_optimizationfunction(f2, adsel)
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)

Check warning on line 59 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L58-L59

Added lines #L58 - L59 were not covered by tests

algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)

Check warning on line 62 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L61-L62

Added lines #L61 - L62 were not covered by tests

transformed_mode = optimization_result.u
result_mode = inv_trafo(transformed_mode)

Check warning on line 65 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L64-L65

Added lines #L64 - L65 were not covered by tests

(result = result_mode, result_trafo = transformed_mode, trafo = trafo, info = optimization_result)

Check warning on line 67 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L67

Added line #L67 was not covered by tests
end



end # module BATOptimizationExt
1 change: 1 addition & 0 deletions src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ function __init__()
@require HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" include("../ext/BATHDF5Ext.jl")
@require NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" include("../ext/BATNestedSamplersExt.jl")
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("../ext/BATOptimExt.jl")
@require Optimization = "429524aa-4258-5aef-a3af-852621145aeb" @require ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" include("../ext/BATOptimizationExt.jl")
@require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("../ext/BATPlotsExt.jl")
@require UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" include("../ext/BATUltraNestExt.jl")
end
Expand Down
1 change: 1 addition & 0 deletions src/extdefs/extdefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ include("advancedhmc_defs.jl")
include("cuba_defs.jl")
include("nestedsamplers_defs.jl")
include("optim_defs.jl")
include("optimization_defs.jl")
include("ultranest_defs.jl")
7 changes: 5 additions & 2 deletions src/extdefs/optim_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ $(TYPEDFIELDS)
"""
@with_kw struct OptimAlg{
ALG,
OPTS,
TR<:AbstractTransformTarget,
IA<:InitvalAlgorithm
} <: AbstractModeEstimator
optalg::ALG = ext_default(pkgext(Val(:Optim)), Val(:DEFAULT_OPTALG))
options::OPTS = ext_default(pkgext(Val(:Optim)), Val(:DEFAULT_OPTS))
trafo::TR = PriorToGaussian()
init::IA = InitFromTarget()
maxiters::Int = 1_000
maxtime::Float64 = NaN
abstol::Float64 = NaN
reltol::Float64 = 0.0
kwargs::NamedTuple = (;)
end
export OptimAlg
44 changes: 44 additions & 0 deletions src/extdefs/optimization_defs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).


"""
OptimizationAlg
Selects an optimization algorithm from the
[Optimization.jl](https://github.com/SciML/Optimization.jl)
package.
Note that when using first order algorithms like `OptimizationOptimJL.LBFGS`, your
[`BATContext`](@ref) needs to include an `ADSelector` that specifies
which automatic differentiation backend should be used.
Constructors:
* ```$(FUNCTIONNAME)(; fields...)```
`optalg` must be an `Optimization.AbstractOptimizer`.
Fields:
$(TYPEDFIELDS)
!!! note
This algorithm is only available if the Optimization package is loaded (e.g. via
`import Optimization`.
"""
@with_kw struct OptimizationAlg{
ALG,
TR<:AbstractTransformTarget,
IA<:InitvalAlgorithm
} <: AbstractModeEstimator
optalg::ALG = ext_default(pkgext(Val(:Optimization)), Val(:DEFAULT_OPTALG))
trafo::TR = PriorToGaussian()
init::IA = InitFromTarget()
maxiters::Int64 = 1_000
maxtime::Float64 = NaN
abstol::Float64 = NaN
reltol::Float64 = 0.0
kwargs::NamedTuple = (;)
end
export OptimizationAlg

0 comments on commit fb70273

Please sign in to comment.