Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store multivariate proposal in MHProposalState #458

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 47 additions & 22 deletions src/samplers/mcmc/mh_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

$(TYPEDFIELDS)
"""
@with_kw struct RandomWalk{Q<:ContinuousUnivariateDistribution} <: MCMCProposal
@with_kw struct RandomWalk{Q<:Union{AbstractMeasure,Distribution{<:Union{Univariate,Multivariate},Continuous}}} <: MCMCProposal
proposaldist::Q = TDist(1.0)
end

export RandomWalk

struct MHProposalState{Q<:ContinuousUnivariateDistribution} <: MCMCProposalState
struct MHProposalState{Q<:BATMeasure} <: MCMCProposalState
proposaldist::Q
end
export MHProposalState
Expand All @@ -44,36 +44,66 @@
MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500))


function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, stepno::Integer, sample_type::Integer)
return MCMCSampleID(id, cycle, stepno, sample_type), MCMCSampleID
end


function _create_proposal_state(
proposal::RandomWalk,
target::BATMeasure,
context::BATContext,
v_init::AbstractVector{<:Real},
rng::AbstractRNG
)
return MHProposalState(proposal.proposaldist)
n_dims = length(v_init)
mv_pdist = batmeasure(_full_random_walk_proposal(proposal.proposaldist, n_dims))
return MHProposalState(mv_pdist)
end


function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, stepno::Integer, sample_type::Integer)
return MCMCSampleID(id, cycle, stepno, sample_type), MCMCSampleID
function _full_random_walk_proposal(m::AbstractMeasure, n_dims::Integer)
x = testvalue(m)
@argcheck x isa AbstractVector{<:Real} && length(x) == n_dims
return m

Check warning on line 68 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L65-L68

Added lines #L65 - L68 were not covered by tests
end

function _full_random_walk_proposal(m::BATDistMeasure, n_dims::Integer)
d = convert(Distribution, m)
return batmeasure(_full_random_walk_proposal(d, n_dims))

Check warning on line 73 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L71-L73

Added lines #L71 - L73 were not covered by tests
end

function _full_random_walk_proposal(d::Distribution{Multivariate,Continuous}, n_dims::Integer)
@assert false
@argcheck length(d) == n_dims
return d

Check warning on line 79 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L76-L79

Added lines #L76 - L79 were not covered by tests
end

# Theoretical optimally proposal scale for random walk with gaussian proposal, according to
# [Gelman et al., Ann. Appl. Probab. 7 (1) 110 - 120, 1997](https://doi.org/10.1214/aoap/1034625254)
_optimal_proposal_scale(d::ContinuousUnivariateDistribution, n_dims::Integer) = 2.38 / sqrt(n_dims) / sqrt(var(d))
function _full_random_walk_proposal(d::Normal, n_dims::Integer)

Check warning on line 82 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L82

Added line #L82 was not covered by tests
# Theoretical optimally proposal scale for random walk with gaussian proposal, according to
# [Gelman et al., Ann. Appl. Probab. 7 (1) 110 - 120, 1997](https://doi.org/10.1214/aoap/1034625254):
proposal_scale = 2.38 / sqrt(n_dims)

Check warning on line 85 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L85

Added line #L85 was not covered by tests

# Determined experimentally for TDist
const _tdist_corr_exp = [0.5, 0.2, 0.14, 0.085, 0.06, 0.045, 0.035, 0.02, 0.015, 0.015]
function _optimal_proposal_scale(d::TDist, n_dims::Integer)
ν_int = round(Int, d.ν)
k = ν_int > 10 ? zero(eltype(_tdist_corr_exp)) : _tdist_corr_exp[ν_int]
2.38 / sqrt(n_dims) / n_dims^k
@argcheck mean(d) ≈ 0
σ² = var(d)
Σ = ScalMat(n_dims, proposal_scale^2 * σ²)
return MvNormal(Σ)

Check warning on line 90 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L87-L90

Added lines #L87 - L90 were not covered by tests
end

function _full_random_walk_proposal(d::TDist, n_dims::Integer)
# Theoretically optimal proposal scale for gaussian seems to work quite well for
# t-distribution proposals with any degrees of freedom as well:
proposal_scale = 2.38 / sqrt(n_dims)

ν = dof(d)
Σ = ScalMat(n_dims, proposal_scale^2)
return Distributions.IsoTDist(ν, Σ)
end


const MHChainState = MCMCChainState{<:BATMeasure, <:RNGPartition, <:Function, <:MHProposalState}


function mcmc_propose!!(mc_state::MHChainState)
@unpack target, proposal, f_transform, context = mc_state
rng = get_rng(context)
Expand All @@ -85,11 +115,9 @@

z_current, logd_z_current = sample_z_current.v, sample_z_current.logd
T = eltype(z_current)
n_dims = size(z_current, 1)

proposal_scale = T(_optimal_proposal_scale(pdist, n_dims))

z_proposed = z_current + proposal_scale .* T.(rand(rng, pdist, n_dims)) #TODO: check if proposal is symmetric? otherwise need additional factor?
# ToDo: Use gen-context:
z_proposed = z_current + T.(rand(rng, pdist))
x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed)
logd_x_proposed = BAT.checked_logdensityof(target, x_proposed)
logd_z_proposed = logd_x_proposed + ladj
Expand All @@ -99,12 +127,9 @@
mc_state.samples[proposed_x_idx] = DensitySample(x_proposed, logd_x_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)
mc_state.sample_z[2] = DensitySample(z_proposed, logd_z_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)

# TODO: MD, should we check for symmetriy of proposal distribution?
# TODO: check if proposal is symmetric - otherwise need Hastings correction:
p_accept = clamp(exp(logd_z_proposed - logd_z_current), 0, 1)


@assert p_accept >= 0

accepted = rand(rng) <= p_accept

return mc_state, accepted, p_accept
Expand Down
Loading