Skip to content

Commit

Permalink
Rename mcmc_iterate! -> mcmc_iterate!!
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Sep 26, 2024
1 parent d711fd2 commit a084518
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 44 deletions.
2 changes: 1 addition & 1 deletion ext/BATAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using BAT: MCMCState, HMCState, HamiltonianMC, HMCProposalState, MCMCStateInfo,
using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples
using BAT: AbstractTransformTarget
using BAT: RNGPartition, get_rng, set_rng!
using BAT: mcmc_step!, nsamples, nsteps, samples_available, eff_acceptance_ratio
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio
using BAT: get_samples!, get_mcmc_tuning, reset_rng_counters!
using BAT: tuning_init!, tuning_postinit!, tuning_reinit!, tuning_update!, tuning_finalize!, tuning_callback
using BAT: totalndof, measure_support, checked_logdensityof
Expand Down
6 changes: 3 additions & 3 deletions src/samplers/mcmc/chain_pool_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ function mcmc_init!(

@debug "Testing $(length(new_tuners)) candidate MCMC chain state(s)."

mcmc_iterate!(
new_mc_states, new_tuners, new_temperers = mcmc_iterate!!(
new_outputs, new_mc_states;
tuners = new_tuners,temperers = new_temperers,
tuners = new_tuners, temperers = new_temperers,
max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50),
nonzero_weights = nonzero_weights
)
Expand All @@ -124,7 +124,7 @@ function mcmc_init!(
@debug "Found $(length(viable_idxs)) viable MCMC chain state(s)."

if !isempty(viable_tuners)
mcmc_iterate!(
viable_mc_states, viable_tuners, viable_temperers = mcmc_iterate!!(
viable_outputs, viable_mc_states;
tuners = viable_tuners, temperers = viable_temperers,
max_nsteps = init_alg.nsteps_init,
Expand Down
34 changes: 20 additions & 14 deletions src/samplers/mcmc/mcmc_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ BAT.get_samples!(samples::DensitySampleVector, chain::SomeMCMCIter, nonzero_weig
BAT.next_cycle!(chain::SomeMCMCIter)::SomeMCMCIter
BAT.mcmc_step!(
BAT.mcmc_step!!(
chain::SomeMCMCIter
callback::Function,
)::nothing
Expand All @@ -117,8 +117,8 @@ The following methods are implemented by default:
getproposal(chain::MCMCIterator)
mcmc_target(chain::MCMCIterator)
DensitySampleVector(chain::MCMCIterator)
mcmc_iterate!(chain::MCMCIterator, ...)
mcmc_iterate!(chains::AbstractVector{<:MCMCIterator}, ...)
mcmc_iterate!!(chain::MCMCIterator, ...)
mcmc_iterate!!(chains::AbstractVector{<:MCMCIterator}, ...)
isvalidchain(chain::MCMCIterator)
isviablechain(chain::MCMCIterator)
```
Expand Down Expand Up @@ -169,7 +169,7 @@ function get_samples! end

function next_cycle! end

function mcmc_step! end
function mcmc_step!! end


abstract type AbstractMCMCTunerInstance end
Expand Down Expand Up @@ -200,10 +200,11 @@ function isvalidstate end
function isviablestate end


function mcmc_iterate! end
function mcmc_iterate!! end

# TODO: MD, reincorporate user callback
# TODO: MD, incorporate use of Tempering, so far temperer is not used
function mcmc_iterate!(
function mcmc_iterate!!(
output::Union{DensitySampleVector,Nothing},
mc_state::MCMCIterator,
tuner::Union{AbstractMCMCTunerInstance,Nothing},
Expand All @@ -224,7 +225,8 @@ function mcmc_iterate!(
(nsteps(mc_state) - start_nsteps) < max_nsteps &&
(time() - start_time) < max_time
)
mcmc_step!(mc_state, tuner, temperer)
mc_state, tuner, temperer = mcmc_step!!(mc_state, tuner, temperer)

if !isnothing(output)
get_samples!(output, mc_state, nonzero_weights)
end
Expand All @@ -241,11 +243,11 @@ function mcmc_iterate!(
elapsed_time = current_time - start_time
@debug "Finished iteration over MCMC chain $(mc_state.info.id), completed $(nsteps(mc_state) - start_nsteps) steps and produced $(nsamples(mc_state) - start_nsamples) samples in $(@sprintf "%.1f s" elapsed_time)."

return nothing
return mc_state, tuner, temperer
end


function mcmc_iterate!(
function mcmc_iterate!!(
output::Union{DensitySampleVector,Nothing},
mc_state::MCMCIterator;
tuner::Union{AbstractMCMCTunerInstance, Nothing} = nothing,
Expand All @@ -254,16 +256,16 @@ function mcmc_iterate!(
max_time::Real = Inf,
nonzero_weights::Bool = true
)
mcmc_iterate!(
mc_state_new, tuner_new, temperer_new = mcmc_iterate!!(
output, mc_state, tuner, temperer;
max_nsteps = max_nsteps, max_time = max_time, nonzero_weights = nonzero_weights
)

return nothing
return mc_state_new, tuner_new, temperer_new
end


function mcmc_iterate!(
function mcmc_iterate!!(
outputs::Union{AbstractVector{<:DensitySampleVector},Nothing},
mc_states::AbstractVector{<:MCMCIterator};
tuners::Union{AbstractVector{<:AbstractMCMCTunerInstance},Nothing} = nothing,
Expand All @@ -281,11 +283,15 @@ function mcmc_iterate!(
tnrs = isnothing(tuners) ? fill(nothing, size(mc_states)...) : tuners
tmrs = isnothing(temperers) ? fill(nothing, size(mc_states)...) : temperers

mc_states_new = similar(mc_states)
tuners_new = similar(tnrs)
temperers_new = similar(tmrs)

@sync for i in eachindex(outs, mc_states, tnrs)
Base.Threads.@spawn mcmc_iterate!(outs[i], mc_states[i]; tuner = tnrs[i], temperer = tmrs[i], kwargs...)
Base.Threads.@spawn mc_states_new[i], tuners_new[i], temperers_new[i] = mcmc_iterate!!(outs[i], mc_states[i]; tuner = tnrs[i], temperer = tmrs[i], kwargs...)
end

return nothing
return mc_states_new, tuners_new, temperers_new
end


Expand Down
2 changes: 1 addition & 1 deletion src/samplers/mcmc/mcmc_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function bat_sample_impl(target::BATMeasure, sampling::MCMCSampling, context::BA

next_cycle!.(mc_states)

mcmc_iterate!(
mc_states, REMOVE_dummy_tuners, REMOVE_dummy_temperers = mcmc_iterate!!(
chain_outputs,
mc_states;
max_nsteps = sampling.nsteps,
Expand Down
44 changes: 39 additions & 5 deletions src/samplers/mcmc/mcmc_state.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).


# struct MCMCStates{
# C<:MCMCState,
# ...
# }
# chain_state::C
# proposal_tuner_state::PT
# transform_tuner_state::TT
# tempering_state::T
# end

# TODO: MD, adjust docstring to new typestructure
# TODO: MD, use Accessors.jl to make immutable
# TODO: MD, rename to MCMCChainState
mutable struct MCMCState{
M<:BATMeasure,
PR<:RNGPartition,
Expand Down Expand Up @@ -116,8 +129,9 @@ function DensitySampleVector(mc_state::MCMCState)
DensitySampleVector(sample_type(mc_state), totalndof(varshape(mcmc_target(mc_state))))
end


function mcmc_step!(mc_state::MCMCState, tuner::Union{AbstractMCMCTunerInstance, Nothing}, temperer::Union{AbstractMCMCTemperingInstance, Nothing})
# TODO: MD, make into !!
# TODO: MD, make NoOpTunerState to avoid Union nothing in type
function mcmc_step!!(mc_state::MCMCState, tuner_state::Union{AbstractMCMCTunerInstance, Nothing}, temperer::Union{AbstractMCMCTemperingInstance, Nothing}) # ,proposal_tuner_state
# TODO: MD, include sample_z in _cleanup_samples()
_cleanup_samples(mc_state)
reset_rng_counters!(mc_state)
Expand All @@ -133,16 +147,21 @@ function mcmc_step!(mc_state::MCMCState, tuner::Union{AbstractMCMCTunerInstance,

mc_state, accepted, p_accept = mcmc_propose!!(mc_state)

tuner_new, f_transform_tuned = mcmc_tune_transform!!(mc_state, tuner, p_accept)
# TODO: MD, return a bool if the transform is changed
tuner_new, f_transform_tuned = mcmc_tune_transform!!(mc_state, tuner_state, p_accept)

# TODO: MD, Discuss updating of 'sample_z' due to possibly changed 'f_transform' during transfom tuning_callback
#proosal_new, proposal_tuner_new = mcmc_tune_proposal!!(mc_state, proposal_tuner_state)

current = _current_sample_idx(mc_state)
proposed = _proposed_sample_idx(mc_state)

_accept_reject!(mc_state, accepted, p_accept, current, proposed)

nothing
#mc_state_new =
#temperer_new = temperer_new


return mc_state, tuner_new, temperer
end


Expand Down Expand Up @@ -212,3 +231,18 @@ function samples_available(mc_state::MCMCState)
i = _current_sample_idx(mc_state)
mc_state.samples.info.sampletype[i] == ACCEPTED_SAMPLE
end

function mcmc_update_z_position!!(mc_state::MCMCState)

proposed_sample_x = proposed_sample(mc_state)

x_proposed, logd_x_proposed = proposed_sample_x.v, proposed_sample_x.logd

z_new, ladj_inv = with_logabsdet_jacobian(inverse(mc_state.f_transform), x_proposed)

logd_z_new = logd_x_proposed - ladj_inv

mc_state_new = @set mc_state.sample_z[2].v, mc_state.sample_z[2].logd = z_new, logd_z_new

return mc_state_new
end
5 changes: 4 additions & 1 deletion src/samplers/mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ end

export AdaptiveMHTuning

# TODO: MD, make immutable and use Accessors.jl
mutable struct ProposalCovTunerState{
S<:MCMCBasicStats
} <: AbstractMCMCTunerInstance
Expand Down Expand Up @@ -121,7 +122,7 @@ function tuning_postinit!(tuner::ProposalCovTunerState, mc_state::MHState, sampl
append!(stats, samples)
end


# TODO, MD: Rename to mcmc_tune_transform_next_cylce!!()
function tuning_update!(tuner::ProposalCovTunerState, mc_state::MHState, samples::DensitySampleVector)
tuning = tuner.tuning
stats = tuner.stats
Expand Down Expand Up @@ -185,6 +186,8 @@ tuning_callback(::ProposalCovTunerState) = nop_func

tuning_callback(::Nothing) = nop_func


# add a boold to return if the transfom changes
function mcmc_tune_transform!!(
mc_state::MCMCState,
tuner::ProposalCovTunerState,
Expand Down
1 change: 1 addition & 0 deletions src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ tuning_callback(::RAMTunerState) = nop_func

default_adaptive_transform(tuner::RAMTuning) = TriangularAffineTransform()

# Return mc_state instead of f_transform
function mcmc_tune_transform!!(
mc_state::MCMCState,
tuner::RAMTunerState,
Expand Down
5 changes: 4 additions & 1 deletion src/samplers/mcmc/mh_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ const MHState = MCMCState{<:BATMeasure,
<:BATContext
}

# TODO: MD, should this be a !! function?
function mcmc_propose!!(mc_state::MHState)
@unpack target, proposal, f_transform, context = mc_state
rng = get_rng(context)
Expand Down Expand Up @@ -95,6 +94,10 @@ function mcmc_propose!!(mc_state::MHState)

accepted = rand(rng) <= p_accept

# if accepted
# mc_state_new = mcmc_update_z_position!!(mc_state)
# end

return mc_state, accepted, p_accept
end

Expand Down
4 changes: 2 additions & 2 deletions src/samplers/mcmc/multi_cycle_burnin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function mcmc_burnin!(

tuning_reinit!.(tuners, mc_states, burnin.nsteps_per_cycle)

mcmc_iterate!(
mc_states, tuners, REMOVE_dummy_temperers = mcmc_iterate!!(
new_outputs, mc_states;
tuners = tuners,
max_nsteps = burnin.nsteps_per_cycle,
Expand Down Expand Up @@ -92,7 +92,7 @@ function mcmc_burnin!(

next_cycle!.(mc_states)

mcmc_iterate!(
mc_states, REMOVE_dummy_tuners, REMOVE_dummy_temperers = mcmc_iterate!!(
outputs, mc_states,
max_nsteps = burnin.nsteps_final,
nonzero_weights = nonzero_weights
Expand Down
6 changes: 3 additions & 3 deletions test/samplers/mcmc/test_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ import AdvancedHMC
BAT.tuning_init!(tuner, mc_state, 0)
BAT.tuning_reinit!(tuner, mc_state, div(nsteps, 10))
samples = DensitySampleVector(mc_state)
BAT.mcmc_iterate!(samples, mc_state; tuner = tuner, max_nsteps = nsteps, nonzero_weights = false)
mc_state, tuner, REMOVE_dummy_temperer = BAT.mcmc_iterate!!(samples, mc_state; tuner = tuner, max_nsteps = nsteps, nonzero_weights = false)
@test mc_state.stepno == nsteps
@test minimum(samples.weight) == 0
@test isapprox(length(samples), nsteps, atol = 20)
@test length(samples) == sum(samples.weight)
@test BAT.test_dist_samples(unshaped(objective), samples)

samples = DensitySampleVector(mc_state)
BAT.mcmc_iterate!(samples, mc_state, max_nsteps = 10^3, nonzero_weights = true)
mc_state, REMOVE_dummy_tuner, REMOVE_dummy_temperer = BAT.mcmc_iterate!!(samples, mc_state, max_nsteps = 10^3, nonzero_weights = true)
@test minimum(samples.weight) == 1
end

Expand Down Expand Up @@ -90,7 +90,7 @@ import AdvancedHMC
callback
)

BAT.mcmc_iterate!(
mc_states, REMOVE_dummy_tuners, REMOVE_dummy_temperers = BAT.mcmc_iterate!!(
outputs,
mc_states;
max_nsteps = div(max_nsteps, length(mc_states)),
Expand Down
26 changes: 13 additions & 13 deletions test/samplers/mcmc/test_mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI
# TODO: MD, Reactivate type inference tests
# @test @inferred(BAT.MCMCState(sampling, target, 1, unshaped(v_init, varshape(target)), deepcopy(context))) isa BAT.MHState
# chain = @inferred(BAT.MCMCState(sampling, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)))
chain = BAT.MCMCState(sampling, target, 1, unshaped(v_init, varshape(target)), deepcopy(context))
samples = DensitySampleVector(chain)
BAT.mcmc_iterate!(samples, chain, max_nsteps = 10^5, nonzero_weights = false)
@test chain.stepno == 10^5
mc_state = BAT.MCMCState(sampling, target, 1, unshaped(v_init, varshape(target)), deepcopy(context))
samples = DensitySampleVector(mc_state)
mc_state, REMOVE_dummy_tuner, REMOVE_dummy_temperer = BAT.mcmc_iterate!!(samples, mc_state, max_nsteps = 10^5, nonzero_weights = false)
@test mc_state.stepno == 10^5
@test minimum(samples.weight) == 0
@test isapprox(length(samples), 10^5, atol = 20)
@test length(samples) == sum(samples.weight)
@test isapprox(mean(samples), [1, -1, 2], atol = 0.2)
@test isapprox(cov(samples), cov(unshaped(objective)), atol = 0.3)

samples = DensitySampleVector(chain)
BAT.mcmc_iterate!(samples, chain, max_nsteps = 10^3, nonzero_weights = true)
samples = DensitySampleVector(mc_state)
mc_state, REMOVE_dummy_tuner, REMOVE_dummy_temperer = BAT.mcmc_iterate!!(samples, mc_state, max_nsteps = 10^3, nonzero_weights = true)
@test minimum(samples.weight) == 1
end

Expand Down Expand Up @@ -68,29 +68,29 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI
context
))

(chains, tuners, outputs) = init_result
(mc_states, tuners, outputs) = init_result

# TODO: MD, Reactivate, for some reason fail
# @test chains isa AbstractVector{<:BAT.MHState}
# @test mc_states isa AbstractVector{<:BAT.MHState}
# @test tuners isa AbstractVector{<:BAT.ProposalCovTunerState}
@test outputs isa AbstractVector{<:DensitySampleVector}

BAT.mcmc_burnin!(
outputs,
tuners,
chains,
mc_states,
sampling,
callback
)

BAT.mcmc_iterate!(
mc_states, REMOVE_dummy_tuners, REMOVE_dummy_temperers = BAT.mcmc_iterate!!(
outputs,
chains;
max_nsteps = div(max_nsteps, length(chains)),
mc_states;
max_nsteps = div(max_nsteps, length(mc_states)),
nonzero_weights = nonzero_weights
)

samples = DensitySampleVector(first(chains))
samples = DensitySampleVector(first(mc_states))
append!.(Ref(samples), outputs)

@test length(samples) == sum(samples.weight)
Expand Down

0 comments on commit a084518

Please sign in to comment.