Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to pass Optim.Options to bat_findmode #423

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ jobs:
- uses: codecov/codecov-action@v4
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
with:
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
file: lcov.info
docs:
name: Documentation
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
*.jl.mem
.ipynb_checkpoints
Manifest.toml
.vscode/settings.json
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"
Expand All @@ -82,10 +84,12 @@ BATFoldsExt = ["Folds", "Transducers"]
BATHDF5Ext = "HDF5"
BATNestedSamplersExt = "NestedSamplers"
BATOptimExt = "Optim"
BATOptimizationExt = ["Optimization", "ADTypes"]
BATPlotsExt = "Plots"
BATUltraNestExt = "UltraNest"

[compat]
ADTypes = "0.1, 0.2"
Accessors = "0.1"
Adapt = "3, 4"
AdvancedHMC = "0.5"
Expand Down Expand Up @@ -132,6 +136,7 @@ Measurements = "2"
NamedArrays = "0.9, 0.10"
NestedSamplers = "0.8"
Optim = "0.19,0.20, 0.21, 0.22, 1"
Optimization = "3"
PDMats = "0.9, 0.10, 0.11"
ParallelProcessingTools = "0.4"
Parameters = "0.12"
Expand All @@ -158,12 +163,14 @@ ZygoteRules = "0.2"
julia = "1.6"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"
16 changes: 16 additions & 0 deletions docs/src/list_of_algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,22 @@ bat_findmode(target, OptimAlg(optalg = Optim.LBFGS()))
Requires the [Optim](https://github.com/JuliaNLSolvers/Optim.jl) Julia package to be loaded explicitly.


### Optimization.jl Optimization Algorithms

BAT mode finding algorithm type: [`OptimizationAlg`](@ref).

```julia
using OptimizationOptimJL

alg = OptimizationAlg(;
optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10),
maxiters=200,
kwargs=(f_calls_limit=50,)
)
bat_findmode(target, alg)
```
Requires one of the [Optimization.jl](https://github.com/SciML/Optimization.jl) packages to be loaded explicitly.

### Maximum Sample Estimator

BAT mode finding algorithm type: [`MaxDensitySearch`](@ref)
Expand Down
1 change: 1 addition & 0 deletions docs/src/stable_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ MetropolisHastings
MHProposalDistTuning
ModeAsDefined
OptimAlg
OptimizationAlg
OrderedResampling
PosteriorMeasure
PriorSubstitution
Expand Down
116 changes: 116 additions & 0 deletions examples/dev-internal/test_findmode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
using BAT
using Optim

posterior = BAT.example_posterior()

optalg = OptimAlg(;
optalg = Optim.NelderMead(parameters=Optim.FixedParameters()),
maxiters=200,
kwargs = (f_calls_limit=100,),
)

my_mode = bat_findmode(posterior, optalg)

fieldnames(typeof(my_mode.info.res))


using BAT
#using Optim
#using Optimization
using OptimizationOptimJL

using InverseFunctions, FunctionChains, DensityInterface


posterior = BAT.example_posterior()
optalg = OptimizationAlg(; optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), maxiters=200, kwargs=(f_calls_limit=500,))
my_result = bat_findmode(posterior, optalg)

a = my_result.info

@test a.cache.solver_args.maxiters == 500

dump(a.alg)

fieldnames(typeof(a.cache.solver_args))

fieldnames(typeof(a.original.method))

my_mode.info.original




# Define a NamedTuple with keyword arguments
nt = (a=1, b=2)

# Define a function that accepts keyword arguments
function my_function(; a=0, b=0, c=0)
println("a = $a")
println("b = $b")
println("c = $c")
end

# Call the function and unpack the NamedTuple
my_function(; nt...)



















context = get_batcontext()
target = posterior
transformed_density, trafo = BAT.transform_and_unshape(PriorToGaussian(), target, context)
inv_trafo = inverse(trafo)
initalg = BAT.apply_trafo_to_init(trafo, InitFromTarget())
x_init = collect(bat_initval(transformed_density, initalg, context).result)

f = fchain(inv_trafo, logdensityof(target), -)
f2 = (x, p) -> f(x)


optimization_function = Optimization.OptimizationFunction(f2, Optimization.SciMLBase.NoAD())
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)
optimization_result = Optimization.solve(optimization_problem,OptimizationOptimJL.NelderMead())


optalg = OptimizationAlg(;optalg = OptimizationOptimJL.NelderMead())
my_mode = bat_findmode(posterior, optalg)

my_mode.info.original
fieldnames(typeof(my_mode.info))

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
f = rosenbrock


using AutoDiffOperators

b = Optimization.SciMLBase.NoAD()
supertype(typeof(b))

adm = ADModule(:ForwardDiff)

adsel = BAT.get_adselector(context)
supertype(typeof(adsel))



adm2 = convert_ad(ADTypes.AbstractADType, adm)
ADTypes.AutoForwardDiff()

optimization_function = Optimization.OptimizationFunction(f2, adm2)
27 changes: 21 additions & 6 deletions ext/BATOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:DEFAULT_OPTALG}) = Optim.
BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:NELDERMEAD_ALG}) = Optim.NelderMead()
BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:LBFGS_ALG}) = Optim.LBFGS()


struct NLSolversFG!{F,AD} <: Function
f::F
ad::AD
Expand Down Expand Up @@ -59,6 +58,21 @@ function (fg!::NLSolversFG!)(::Nothing, grad_f::AbstractVector{<:Real}, x::Abstr
end


function convert_options(algorithm::OptimAlg)
if !isnan(algorithm.abstol)
@warn "The option 'abstol=$(algorithm.abstol)' is not used for this algorithm."
end

kwargs = algorithm.kwargs

algopts = (; iterations = algorithm.maxiters, time_limit = algorithm.maxtime, f_tol = algorithm.reltol,)
algopts = (; algopts..., kwargs...)
algopts = (; algopts..., store_trace = true, extended_trace=true)

return Optim.Options(; algopts...)
end


function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context::BATContext)
transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context)
inv_trafo = inverse(trafo)
Expand All @@ -68,7 +82,10 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
optim_result = _optim_minimize(f, x_init, algorithm.optalg, context)

opts = convert_options(algorithm)

optim_result = _optim_minimize(f, x_init, algorithm.optalg, opts, context)
r_optim = Optim.MaximizationWrapper(optim_result)
transformed_mode = Optim.minimizer(r_optim.res)
result_mode = inv_trafo(transformed_mode)
Expand All @@ -80,18 +97,16 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context
(result = result_mode, result_trafo = transformed_mode, trafo = trafo, #=trace_trafo = trace_trafo,=# info = r_optim)
end

function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.ZerothOrderOptimizer, ::BATContext)
opts = Optim.Options(store_trace = true, extended_trace=true)
function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.ZerothOrderOptimizer, opts::Optim.Options, ::BATContext)
_optim_optimize(f, x_init, algorithm, opts)
end

function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.FirstOrderOptimizer, context::BATContext)
function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.FirstOrderOptimizer, opts::Optim.Options, context::BATContext)
adsel = get_adselector(context)
if adsel isa _NoADSelected
throw(ErrorException("$(nameof(typeof(algorithm))) requires an ADSelector to be specified in the BAT context"))
end
fg! = NLSolversFG!(f, adsel)
opts = Optim.Options(store_trace = true, extended_trace=true)
_optim_optimize(Optim.only_fg!(fg!), x_init, algorithm, opts)
end

Expand Down
72 changes: 72 additions & 0 deletions ext/BATOptimizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).

module BATOptimizationExt

@static if isdefined(Base, :get_extension)
import Optimization
else
import ..Optimization
end

using BAT
BAT.pkgext(::Val{:Optimization}) = BAT.PackageExtension{:Optimization}()


using Random
using DensityInterface, ChangesOfVariables, InverseFunctions, FunctionChains
using HeterogeneousComputing, AutoDiffOperators
using StructArrays, ArraysOfArrays, ADTypes

using BAT: AnyMeasureOrDensity, AbstractMeasureOrDensity

using BAT: get_context, get_adselector, _NoADSelected
using BAT: bat_initval, transform_and_unshape, apply_trafo_to_init
using BAT: negative


AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg

BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()


function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
adm = convert_ad(ADTypes.AbstractADType, adsel)
optimization_function = Optimization.OptimizationFunction(f, adm)
return optimization_function
end

function build_optimizationfunction(f, adsel::BAT._NoADSelected)
optimization_function = Optimization.OptimizationFunction(f)
return optimization_function
end


function BAT.bat_findmode_impl(target::AnyMeasureOrDensity, algorithm::OptimizationAlg, context::BATContext)
transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context)
inv_trafo = inverse(trafo)

initalg = apply_trafo_to_init(trafo, algorithm.init)
x_init = collect(bat_initval(transformed_density, initalg, context).result)

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
target_f = (x, p) -> f(x)

adsel = get_adselector(context)

optimization_function = build_optimizationfunction(target_f, adsel)
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)

algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)

transformed_mode = optimization_result.u
result_mode = inv_trafo(transformed_mode)

(result = result_mode, result_trafo = transformed_mode, trafo = trafo, info = optimization_result)
end



end # module BATOptimizationExt
1 change: 1 addition & 0 deletions src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ function __init__()
@require HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" include("../ext/BATHDF5Ext.jl")
@require NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" include("../ext/BATNestedSamplersExt.jl")
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("../ext/BATOptimExt.jl")
@require Optimization = "429524aa-4258-5aef-a3af-852621145aeb" @require ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" include("../ext/BATOptimizationExt.jl")
@require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("../ext/BATPlotsExt.jl")
@require UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" include("../ext/BATUltraNestExt.jl")
end
Expand Down
1 change: 1 addition & 0 deletions src/extdefs/extdefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ include("advancedhmc_defs.jl")
include("cuba_defs.jl")
include("nestedsamplers_defs.jl")
include("optim_defs.jl")
include("optimization_defs.jl")
include("ultranest_defs.jl")
5 changes: 5 additions & 0 deletions src/extdefs/optim_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,10 @@ $(TYPEDFIELDS)
optalg::ALG = ext_default(pkgext(Val(:Optim)), Val(:DEFAULT_OPTALG))
trafo::TR = PriorToGaussian()
init::IA = InitFromTarget()
maxiters::Int = 1_000
maxtime::Float64 = NaN
abstol::Float64 = NaN
reltol::Float64 = 0.0
kwargs::NamedTuple = (;)
end
export OptimAlg
Loading
Loading