-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy path Adaptation.jl
66 lines (55 loc) · 1.59 KB
/
Adaptation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
module Adaptation
export Adaptation
using LinearAlgebra: LinearAlgebra
using Statistics: Statistics
using SimpleUnPack: @unpack, @pack!
using ..AdvancedHMC: DEBUG, AbstractScalarOrVec
using DocStringExtensions
"""
$(TYPEDEF)

Abstract type for HMC adaptors.

Adaptors in this module provide methods for (some of) the generic functions
`getM⁻¹`, `getϵ`, `adapt!`, `reset!`, `initialize!` and `finalize!`.
"""
abstract type AbstractAdaptor end

"""
    getM⁻¹(adaptor)

Return the inverse mass matrix `M⁻¹` currently held by `adaptor`.
"""
function getM⁻¹ end

"""
    getϵ(adaptor)

Return the current step size `ϵ` of `adaptor`.
"""
function getϵ end

"""
    adapt!(adaptor, θ, α)

Update `adaptor` in place from `θ` and `α`. For the composite adaptor in this
module, `θ` is a vector or matrix of floats and `α` a float scalar or vector.
"""
function adapt! end

"""
    reset!(adaptor)

Reset the adaptation state of `adaptor`.
"""
function reset! end

"""
    initialize!(adaptor, n_adapts)

Prepare `adaptor` for adaptation over `n_adapts` steps.
"""
function initialize! end

"""
    finalize!(adaptor)

Complete adaptation for `adaptor`.
"""
function finalize! end

export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹

"""
$(TYPEDEF)

Singleton adaptor type signalling that no adaptation is performed.
"""
struct NoAdaptation <: AbstractAdaptor end
export NoAdaptation
# Component adaptors. The names exported below are not defined in this file, so
# they presumably come from these included sources. Include order is preserved;
# `export` is a declaration, so its position relative to `include` does not matter.
include("stepsize.jl")
include("massmatrix.jl")
export StepSizeAdaptor, NesterovDualAveraging
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov
##
## Composite adaptors
## TODO: generalise this to a list of adaptors
##
"""
$(TYPEDEF)

Composite adaptor holding a mass-matrix adaptor and a step-size adaptor.

# Fields
- `pc`: the mass-matrix adaptor (a `MassMatrixAdaptor`).
- `ssa`: the step-size adaptor (a `StepSizeAdaptor`).
"""
struct NaiveHMCAdaptor{Tpc<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAdaptor
    pc::Tpc
    ssa::Tssa
end
# Print the components directly to `io` instead of interpolating them into a
# string: interpolation goes through `string(...)`, which allocates a temporary
# `String` and drops the caller's `IOContext` (e.g. `:compact`) for nested `show`s.
function Base.show(io::IO, a::NaiveHMCAdaptor)
    print(io, "NaiveHMCAdaptor(pc=", a.pc, ", ssa=", a.ssa, ")")
    return nothing
end
"""
    getM⁻¹(adaptor::NaiveHMCAdaptor)

Return the inverse mass matrix by delegating to the mass-matrix component
`adaptor.pc`.
"""
function getM⁻¹(adaptor::NaiveHMCAdaptor)
    return getM⁻¹(adaptor.pc)
end

"""
    getϵ(adaptor::NaiveHMCAdaptor)

Return the step size by delegating to the step-size component `adaptor.ssa`.
"""
function getϵ(adaptor::NaiveHMCAdaptor)
    return getϵ(adaptor.ssa)
end
# TODO: implement consensus adaptor
"""
    adapt!(adaptor::NaiveHMCAdaptor, θ, α)

Adapt both components of `adaptor` with the same `θ` and `α`: the step-size
adaptor `adaptor.ssa` first, then the mass-matrix adaptor `adaptor.pc`.
Returns whatever `adapt!` returns for the mass-matrix component.
"""
function adapt!(
    adaptor::NaiveHMCAdaptor,
    θ::AbstractVecOrMat{<:AbstractFloat},
    α::AbstractScalarOrVec{<:AbstractFloat},
)
    step_size_adaptor = adaptor.ssa
    mass_matrix_adaptor = adaptor.pc
    adapt!(step_size_adaptor, θ, α)
    return adapt!(mass_matrix_adaptor, θ, α)
end
"""
    reset!(adaptor::NaiveHMCAdaptor)

Reset both components of `adaptor`: the step-size adaptor first, then the
mass-matrix adaptor. Returns whatever `reset!` returns for the mass-matrix
component.
"""
function reset!(adaptor::NaiveHMCAdaptor)
    step_size_adaptor, mass_matrix_adaptor = adaptor.ssa, adaptor.pc
    reset!(step_size_adaptor)
    return reset!(mass_matrix_adaptor)
end
"""
    initialize!(adaptor::NaiveHMCAdaptor, n_adapts::Int)

No-op for the composite adaptor; always returns `nothing`.
"""
initialize!(::NaiveHMCAdaptor, ::Int) = nothing

"""
    finalize!(adaptor::NaiveHMCAdaptor)

Finalize the step-size component `adaptor.ssa`. The mass-matrix component
`adaptor.pc` is not finalized by this method.
"""
function finalize!(adaptor::NaiveHMCAdaptor)
    # NOTE(review): only `ssa` is finalized — presumably `MassMatrixAdaptor` has no
    # `finalize!` step; confirm against massmatrix.jl before extending this.
    return finalize!(adaptor.ssa)
end
# `StanHMCAdaptor` is not defined in this file; it presumably comes from
# stan_adaptor.jl. `export` is a declaration, so it may precede the include.
export NaiveHMCAdaptor, StanHMCAdaptor
include("stan_adaptor.jl")
end # module