-
Notifications
You must be signed in to change notification settings - Fork 221
/
Copy pathabstractmcmc.jl
167 lines (148 loc) · 5.96 KB
/
abstractmcmc.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
    TuringState{S,F}

Wrapper coupling an external sampler's native `state` with the log-density
function `logdensity` it was produced from, so that Turing can later recover
parameter values (see `varinfo(::TuringState)`) without knowing the sampler.
"""
struct TuringState{S,F}
    state::S
    logdensity::F
end
state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
# Convert an external sampler's `transition` into a Turing `Transition` by
# extracting the flat parameter vector and folding it back into the varinfo.
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
    # TODO: We should probably rename this `getparams` since it returns something
    # very different from `Turing.Inference.getparams`.
    params = getparams(f.model, transition)
    updated_varinfo = DynamicPPL.unflatten(f.varinfo, params)
    return Transition(f.model, updated_varinfo, transition)
end
state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state) = TuringState(state, f)
# Unwrap the AD wrapper and dispatch on the underlying log-density function.
# NOTE: qualified as `LogDensityProblemsAD.parent` for consistency with
# `getvarinfo`/`setvarinfo` below; a bare `parent` would resolve to
# `Base.parent` in this module, which is not the intended function.
function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition)
    return transition_to_turing(LogDensityProblemsAD.parent(f), transition)
end
# Unwrap the AD wrapper and recurse to reach the underlying `LogDensityFunction`.
# NOTE: qualified as `LogDensityProblemsAD.parent` for consistency with
# `getvarinfo`/`setvarinfo` below; a bare `parent` would resolve to
# `Base.parent` in this module, which is not the intended function.
function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
    return varinfo_from_logdensityfn(LogDensityProblemsAD.parent(f))
end
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo
# Reconstruct a varinfo for `state` by pulling the flat parameter vector out of
# the inner sampler state and unflattening it into the stored varinfo.
function varinfo(state::TuringState)
    inner_vi = varinfo_from_logdensityfn(state.logdensity)
    params = getparams(DynamicPPL.getmodel(state.logdensity), state.state)
    # TODO: Do we need to link here first?
    return DynamicPPL.unflatten(inner_vi, params)
end
# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
# Extract the flat parameter vector from an AdvancedHMC transition; the
# position is stored in the phase point `z` as `θ`.
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
# An `HMCState` carries its most recent transition; delegate to that method.
getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState) =
    getparams(model, state.transition)
getstats(transition::AdvancedHMC.Transition) = transition.stat
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
# For AD-wrapped log-density functions, unwrap and recurse.
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) =
    getvarinfo(LogDensityProblemsAD.parent(f))
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
# Replace the varinfo on an AD-wrapped log-density function: update the inner
# function, then rebuild the AD wrapper with the given `adtype`.
function setvarinfo(
    f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
)
    inner_updated = setvarinfo(LogDensityProblemsAD.parent(f), varinfo)
    return LogDensityProblemsAD.ADgradient(adtype, inner_updated)
end
"""
    recompute_logprob!!(rng, model, sampler, state)

Recompute the log-probability of the `model` based on the given `state` and
return the resulting state.
"""
function recompute_logprob!!(
    rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:ExternalSampler},
    state,
)
    # Re-use the log-density function from `state` but swap in the given `model`,
    # since the `model` might now contain different conditioning values.
    logdensity_new = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype)
    # Delegate the recomputation to the sampler-specific method, then wrap the
    # result back up as a Turing-facing state.
    state_updated = recompute_logprob!!(
        rng, AbstractMCMC.LogDensityModel(logdensity_new), sampler.alg.sampler, state.state
    )
    return state_to_turing(logdensity_new, state_updated)
end
# Recompute log-probability (and gradient) for an AdvancedHMC state.
function recompute_logprob!!(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::AdvancedHMC.AbstractHMCSampler,
    state::AdvancedHMC.HMCState,
)
    # Construct the Hamiltonian from the stored metric and the given model.
    ham = AdvancedHMC.Hamiltonian(state.metric, model)
    # Recreating the phase point at the current position/momentum re-evaluates
    # the log-density and its gradient.
    fresh_z = AdvancedHMC.phasepoint(ham, state.transition.z.θ, state.transition.z.r)
    return Accessors.@set state.transition.z = fresh_z
end
# Recompute log-probability for an AdvancedMH transition.
function recompute_logprob!!(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::AdvancedMH.MetropolisHastings,
    state::AdvancedMH.Transition,
)
    # Re-evaluate the log-density at the stored parameters and immutably update
    # the transition's `lp` field.
    lp_new = LogDensityProblems.logdensity(model.logdensity, state.params)
    return Accessors.@set state.lp = lp_new
end
# TODO: Do we also support `resume`, etc?
"""
    AbstractMCMC.step(rng, model, sampler_wrapper::Sampler{<:ExternalSampler};
                      initial_state=nothing, initial_params=nothing, kwargs...)

Take the initial step of the wrapped external sampler on `model`.

Builds an AD-capable log-density function for `model`, links its varinfo into
unconstrained space when the external sampler requires it, and then delegates
to the external sampler's own `AbstractMCMC.step`. Returns a Turing-facing
`(transition, state)` pair.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapper::Sampler{<:ExternalSampler};
    initial_state=nothing,
    initial_params=nothing,
    kwargs...,
)
    alg = sampler_wrapper.alg
    sampler = alg.sampler
    # Create a log-density function with an implementation of the
    # gradient so we ensure that we're using the same AD backend as in Turing.
    f = LogDensityProblemsAD.ADgradient(alg.adtype, DynamicPPL.LogDensityFunction(model))
    # Link the varinfo if needed.
    varinfo = getvarinfo(f)
    if requires_unconstrained_space(alg)
        if initial_params !== nothing
            # If we have initial parameters, we need to set the varinfo before linking.
            varinfo = DynamicPPL.link(DynamicPPL.unflatten(varinfo, initial_params), model)
            # Extract initial parameters in unconstrained space.
            initial_params = varinfo[:]
        else
            varinfo = DynamicPPL.link(varinfo, model)
        end
    end
    # Push the (possibly linked) varinfo back into the log-density function.
    f = setvarinfo(f, varinfo, alg.adtype)
    # Then just call `AdvancedHMC.step` with the right arguments.
    if initial_state === nothing
        # First step with no prior inner-sampler state.
        transition_inner, state_inner = AbstractMCMC.step(
            rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
        )
    else
        # Resuming from a caller-provided state of the *inner* sampler.
        transition_inner, state_inner = AbstractMCMC.step(
            rng,
            AbstractMCMC.LogDensityModel(f),
            sampler,
            initial_state;
            initial_params,
            kwargs...,
        )
    end
    # Update the `state`
    return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
end
# Subsequent steps: reuse the log-density function captured in `state` and
# forward to the wrapped sampler's own `step`.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapper::Sampler{<:ExternalSampler},
    state::TuringState;
    kwargs...,
)
    inner_sampler = sampler_wrapper.alg.sampler
    logdensity = state.logdensity
    inner_transition, inner_state = AbstractMCMC.step(
        rng, AbstractMCMC.LogDensityModel(logdensity), inner_sampler, state.state; kwargs...
    )
    # Convert back to Turing-facing transition/state wrappers.
    return (
        transition_to_turing(logdensity, inner_transition),
        state_to_turing(logdensity, inner_state),
    )
end