diff --git a/src/samplers/mcmc/mh_sampler.jl b/src/samplers/mcmc/mh_sampler.jl
index 811fc2c6f..ea9ad59e9 100644
--- a/src/samplers/mcmc/mh_sampler.jl
+++ b/src/samplers/mcmc/mh_sampler.jl
@@ -14,13 +14,13 @@ Fields:
 
 $(TYPEDFIELDS)
 """
-@with_kw struct RandomWalk{Q<:ContinuousUnivariateDistribution} <: MCMCProposal
+@with_kw struct RandomWalk{Q<:Union{AbstractMeasure,Distribution{<:Union{Univariate,Multivariate},Continuous}}} <: MCMCProposal
     proposaldist::Q = TDist(1.0)
 end
 
 export RandomWalk
 
-struct MHProposalState{Q<:ContinuousUnivariateDistribution} <: MCMCProposalState
+struct MHProposalState{Q<:BATMeasure} <: MCMCProposalState
     proposaldist::Q
 end
 export MHProposalState
@@ -44,6 +44,11 @@ bat_default(::Type{TransformedMCMC}, ::Val{:burnin}, proposal::RandomWalk, pretr
     MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500))
 
 
+function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, stepno::Integer, sample_type::Integer)
+    return MCMCSampleID(id, cycle, stepno, sample_type), MCMCSampleID
+end
+
+
 function _create_proposal_state(
     proposal::RandomWalk,
     target::BATMeasure,
@@ -51,29 +56,54 @@ function _create_proposal_state(
     v_init::AbstractVector{<:Real},
     rng::AbstractRNG
 )
-    return MHProposalState(proposal.proposaldist)
+    n_dims = length(v_init)
+    mv_pdist = batmeasure(_full_random_walk_proposal(proposal.proposaldist, n_dims))
+    return MHProposalState(mv_pdist)
 end
 
 
-function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, stepno::Integer, sample_type::Integer)
-    return MCMCSampleID(id, cycle, stepno, sample_type), MCMCSampleID
+function _full_random_walk_proposal(m::AbstractMeasure, n_dims::Integer)
+    x = testvalue(m)
+    @argcheck x isa AbstractVector{<:Real} && length(x) == n_dims
+    return m
+end
+
+function _full_random_walk_proposal(m::BATDistMeasure, n_dims::Integer)
+    d = convert(Distribution, m)
+    return batmeasure(_full_random_walk_proposal(d, n_dims))
 end
 
+function _full_random_walk_proposal(d::Distribution{Multivariate,Continuous}, n_dims::Integer)
+    # A multivariate proposal is used unchanged, it only has to match the number of dimensions:
+    @argcheck length(d) == n_dims
+    return d
+end
 
-# Theoretical optimally proposal scale for random walk with gaussian proposal, according to
-# [Gelman et al., Ann. Appl. Probab. 7 (1) 110 - 120, 1997](https://doi.org/10.1214/aoap/1034625254)
-_optimal_proposal_scale(d::ContinuousUnivariateDistribution, n_dims::Integer) = 2.38 / sqrt(n_dims) / sqrt(var(d))
+function _full_random_walk_proposal(d::Normal, n_dims::Integer)
+    # Theoretically optimal proposal scale for random walk with gaussian proposal, according to
+    # [Gelman et al., Ann. Appl. Probab. 7 (1) 110 - 120, 1997](https://doi.org/10.1214/aoap/1034625254):
+    proposal_scale = 2.38 / sqrt(n_dims)
 
-# Determined experimentally for TDist
-const _tdist_corr_exp = [0.5, 0.2, 0.14, 0.085, 0.06, 0.045, 0.035, 0.02, 0.015, 0.015]
-function _optimal_proposal_scale(d::TDist, n_dims::Integer)
-    ν_int = round(Int, d.ν)
-    k = ν_int > 10 ? zero(eltype(_tdist_corr_exp)) : _tdist_corr_exp[ν_int]
-    2.38 / sqrt(n_dims) / n_dims^k
+    @argcheck mean(d) ≈ 0
+    σ² = var(d)
+    Σ = ScalMat(n_dims, proposal_scale^2 * σ²)
+    return MvNormal(Σ)
 end
 
+function _full_random_walk_proposal(d::TDist, n_dims::Integer)
+    # Theoretically optimal proposal scale for gaussian seems to work quite well for
+    # t-distribution proposals with any degrees of freedom as well:
+    proposal_scale = 2.38 / sqrt(n_dims)
+
+    ν = dof(d)
+    Σ = ScalMat(n_dims, proposal_scale^2)
+    return Distributions.IsoTDist(ν, Σ)
+end
+
+
 const MHChainState = MCMCChainState{<:BATMeasure, <:RNGPartition, <:Function, <:MHProposalState}
 
+
 function mcmc_propose!!(mc_state::MHChainState)
     @unpack target, proposal, f_transform, context = mc_state
     rng = get_rng(context)
@@ -85,11 +115,9 @@ function mcmc_propose!!(mc_state::MHChainState)
     z_current, logd_z_current = sample_z_current.v, sample_z_current.logd
 
     T = eltype(z_current)
-    n_dims = size(z_current, 1)
-
-    proposal_scale = T(_optimal_proposal_scale(pdist, n_dims))
 
-    z_proposed = z_current + proposal_scale .* T.(rand(rng, pdist, n_dims)) #TODO: check if proposal is symmetric? otherwise need additional factor?
+    # ToDo: Use gen-context:
+    z_proposed = z_current + T.(rand(rng, pdist))
     x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed)
     logd_x_proposed = BAT.checked_logdensityof(target, x_proposed)
     logd_z_proposed = logd_x_proposed + ladj
@@ -99,12 +127,9 @@ function mcmc_propose!!(mc_state::MHChainState)
     mc_state.samples[proposed_x_idx] = DensitySample(x_proposed, logd_x_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)
     mc_state.sample_z[2] = DensitySample(z_proposed, logd_z_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)
 
-    # TODO: MD, should we check for symmetriy of proposal distribution?
+    # TODO: check if proposal is symmetric - otherwise need Hastings correction:
     p_accept = clamp(exp(logd_z_proposed - logd_z_current), 0, 1)
 
-
-    @assert p_accept >= 0
-
     accepted = rand(rng) <= p_accept
 
     return mc_state, accepted, p_accept