Skip to content

Commit

Permalink
Introduce Super struct 'MCMCState' for MCMC sampling states, refactor…
Browse files Browse the repository at this point in the history
… tuning
  • Loading branch information
Micki-D committed Sep 27, 2024
1 parent a084518 commit f2131ac
Show file tree
Hide file tree
Showing 23 changed files with 590 additions and 460 deletions.
2 changes: 1 addition & 1 deletion docs/src/internal_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ BAT.FullMeasureTransform
BAT.LFDensity
BAT.LFDensityWithGrad
BAT.LogDVal
BAT.MCMCState
BAT.MCMCChainState
BAT.MCMCSampleGenerator
BAT.MeasureLike
BAT.NoWhitening
Expand Down
8 changes: 4 additions & 4 deletions examples/dev-internal/space_transformations_examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using BAT
using AdvancedHMC
using AffineMaps
using AutoDiffOperators
using Plots
# using Plots
using ValueShapes


Expand All @@ -23,12 +23,12 @@ propcov_result = BAT.bat_sample_impl(posterior,
MCMCSampling(adaptive_transform = f),
context
)
propcov_samples = my_result.result
propcov_samples = propcov_result.result
plot(propcov_samples)


ram_result = BAT.bat_sample_impl(posterior,
MCMCSampling(adaptive_transform = f, tuning = RAMTuning()),
MCMCSampling(adaptive_transform = f, trafo_tuning = RAMTuning()),
context
)
ram_samples = ram_result.result
Expand All @@ -37,7 +37,7 @@ plot(ram_samples)
# Advanced Hamiltonian MC Sampling

hmc_result = BAT.bat_sample_impl(posterior,
MCMCSampling(adaptive_transform = f, proposal = HamiltonianMC(), tuning = StanHMCTuning()),
MCMCSampling(adaptive_transform = f, proposal = HamiltonianMC(), trafo_tuning = StanHMCTuning()),
context
)
hmc_samples = hmc_result.result
Expand Down
4 changes: 2 additions & 2 deletions ext/BATAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ using BAT: MeasureLike, BATMeasure

using BAT: get_context, get_adselector, _NoADSelected
using BAT: getproposal, mcmc_target
using BAT: MCMCState, HMCState, HamiltonianMC, HMCProposalState, MCMCStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, AbstractMCMCTunerInstance
using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, AbstractMCMCTunerState
using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples
using BAT: AbstractTransformTarget
using BAT: RNGPartition, get_rng, set_rng!
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio
using BAT: get_samples!, get_mcmc_tuning, reset_rng_counters!
using BAT: tuning_init!, tuning_postinit!, tuning_reinit!, tuning_update!, tuning_finalize!, tuning_callback
using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!, tuning_callback
using BAT: totalndof, measure_support, checked_logdensityof
using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJECTED_SAMPLE

Expand Down
2 changes: 1 addition & 1 deletion ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
function BAT.next_cycle!(mc_state::HMCState)
_cleanup_samples(mc_state)

mc_state.info = MCMCStateInfo(mc_state.info, cycle = mc_state.info.cycle + 1)
mc_state.info = MCMCChainStateInfo(mc_state.info, cycle = mc_state.info.cycle + 1)
mc_state.nsamples = 0
mc_state.stepno = 0

Expand Down
79 changes: 57 additions & 22 deletions ext/ahmc_impl/ahmc_tuner_impl.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,102 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).


mutable struct HMCTunerState{A<:AdvancedHMC.AbstractAdaptor} <: AbstractMCMCTunerInstance
struct HMCTrafoTunerState <: AbstractMCMCTunerState end

mutable struct HMCProposalTunerState{A<:AdvancedHMC.AbstractAdaptor} <: AbstractMCMCTunerState
tuning::HMCTuning
target_acceptance::Float64
adaptor::A
end

function (tuning::HMCTuning)(mc_state::HMCState)
θ = first(mc_state.samples).v
adaptor = ahmc_adaptor(tuning, mc_state.proposal.hamiltonian.metric, mc_state.proposal.kernel.τ.integrator, θ)
HMCTunerState(tuning, tuning.target_acceptance, adaptor)
(tuning::HMCTuning)(chain_state::HMCState) = HMCProposalTunerState(tuning, chain_state), HMCTrafoTunerState()

Check warning on line 12 in ext/ahmc_impl/ahmc_tuner_impl.jl

View check run for this annotation

Codecov / codecov/patch

ext/ahmc_impl/ahmc_tuner_impl.jl#L12

Added line #L12 was not covered by tests

HMCTrafoTunerState(tuning::HMCTuning) = HMCTrafoTunerState()

function HMCProposalTunerState(tuning::HMCTuning, chain_state::MCMCChainState)
θ = first(chain_state.samples).v
adaptor = ahmc_adaptor(tuning, chain_state.proposal.hamiltonian.metric, chain_state.proposal.kernel.τ.integrator, θ)
HMCProposalTunerState(tuning, tuning.target_acceptance, adaptor)
end

BAT.create_trafo_tuner_state(tuning::HMCTuning, chain_state::MCMCChainState, iteration::Integer) = HMCTrafoTunerState(tuning)

BAT.create_proposal_tuner_state(tuning::HMCTuning, chain_state::MCMCChainState, iteration::Integer) = HMCProposalTunerState(tuning, chain_state)

BAT.mcmc_tuning_init!!(tuner::HMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) = nothing

function BAT.tuning_init!(tuner::HMCTunerState, mc_state::HMCState, max_nsteps::Integer)
function BAT.mcmc_tuning_init!!(tuner::HMCProposalTunerState, chain_state::HMCState, max_nsteps::Integer)
AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1))
nothing
end

BAT.tuning_postinit!(tuner::HMCTunerState, mc_state::HMCState, samples::DensitySampleVector) = nothing
BAT.mcmc_tuning_reinit!!(tuner::HMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) = nothing

function BAT.tuning_reinit!(tuner::HMCTunerState, mc_state::HMCState, max_nsteps::Integer)
function BAT.mcmc_tuning_reinit!!(tuner::HMCProposalTunerState, chain_state::HMCState, max_nsteps::Integer)
AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1))
nothing
end

function BAT.tuning_update!(tuner::HMCTunerState, mc_state::HMCState, samples::DensitySampleVector)

BAT.mcmc_tuning_postinit!!(tuner::HMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing

BAT.mcmc_tuning_postinit!!(tuner::HMCProposalTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing


BAT.mcmc_tune_post_cycle!!(tuner::HMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing

function BAT.mcmc_tune_post_cycle!!(tuner::HMCProposalTunerState, chain_state::HMCState, samples::DensitySampleVector)
max_log_posterior = maximum(samples.logd)
accept_ratio = eff_acceptance_ratio(mc_state)
accept_ratio = eff_acceptance_ratio(chain_state)
if accept_ratio >= 0.9 * tuner.target_acceptance
mc_state.info = MCMCStateInfo(mc_state.info, tuned = true)
@debug "MCMC chain $(mc_state.info.id) tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(mc_state.proposal.τ.integrator), max. log posterior = $(Float32(max_log_posterior))"
chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = true)
@debug "MCMC chain $(chain_state.info.id) tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.τ.integrator), max. log posterior = $(Float32(max_log_posterior))"
else
mc_state.info = MCMCStateInfo(mc_state.info, tuned = false)
@debug "MCMC chain $(mc_state.info.id) *not* tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(mc_state.proposal.τ.integrator), max. log posterior = $(Float32(max_log_posterior))"
chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false)
@debug "MCMC chain $(chain_state.info.id) *not* tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.τ.integrator), max. log posterior = $(Float32(max_log_posterior))"

Check warning on line 56 in ext/ahmc_impl/ahmc_tuner_impl.jl

View check run for this annotation

Codecov / codecov/patch

ext/ahmc_impl/ahmc_tuner_impl.jl#L55-L56

Added lines #L55 - L56 were not covered by tests
end
nothing
end

function BAT.tuning_finalize!(tuner::HMCTunerState, mc_state::HMCState)

BAT.mcmc_tuning_finalize!!(tuner::HMCTrafoTunerState, chain_state::HMCState) = nothing

function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::HMCState)
adaptor = tuner.adaptor
proposal = mc_state.proposal
proposal = chain_state.proposal
AdvancedHMC.finalize!(adaptor)
proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor)
proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor)
nothing
end


function BAT.mcmc_tune_transform!!(
mc_state::MCMCState,
tuner::HMCTunerState,
BAT.tuning_callback(::HMCTrafoTunerState) = nop_func

Check warning on line 74 in ext/ahmc_impl/ahmc_tuner_impl.jl

View check run for this annotation

Codecov / codecov/patch

ext/ahmc_impl/ahmc_tuner_impl.jl#L74

Added line #L74 was not covered by tests

BAT.tuning_callback(::HMCProposalTunerState) = nop_func

Check warning on line 76 in ext/ahmc_impl/ahmc_tuner_impl.jl

View check run for this annotation

Codecov / codecov/patch

ext/ahmc_impl/ahmc_tuner_impl.jl#L76

Added line #L76 was not covered by tests


function BAT.mcmc_tune_post_step!!(
tuner_state::HMCTrafoTunerState,
chain_state::MCMCChainState,
p_accept::Real
)
adaptor = tuner.adaptor
proposal = mc_state.proposal
return chain_state, tuner_state, chain_state.f_transform
end

function BAT.mcmc_tune_post_step!!(
tuner_state::HMCProposalTunerState,
chain_state::MCMCChainState,
p_accept::Real
)
adaptor = tuner_state.adaptor
proposal = chain_state.proposal
tstat = AdvancedHMC.stat(proposal.transition)

AdvancedHMC.adapt!(adaptor, proposal.transition.z.θ, tstat.acceptance_rate)
proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor)
proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor)
tstat = merge(tstat, (is_adapt =true,))

return (tuner, mc_state.f_transform)
return chain_state, tuner_state, chain_state.f_transform
end
2 changes: 1 addition & 1 deletion src/extdefs/ahmc_defs/ahmc_alg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end

export HMCProposalState

const HMCState = MCMCState{<:BATMeasure,
const HMCState = MCMCChainState{<:BATMeasure,
<:RNGPartition,
<:Function,
<:HMCProposalState,
Expand Down
Loading

0 comments on commit f2131ac

Please sign in to comment.