Commit

Merge branch 'master' into torfjelde/new-gibbs

yebai authored Feb 27, 2024
2 parents b1dcadf + 4b5e4d7 commit 0f0bfd5
Showing 13 changed files with 96 additions and 156 deletions.
7 changes: 7 additions & 0 deletions HISTORY.md
@@ -1,3 +1,10 @@
# Release 0.30.5

- `essential/ad.jl` has been removed; the `ForwardDiff` and `ReverseDiff` integrations via `LogDensityProblemsAD` have moved to `DynamicPPL` and live in the corresponding package extensions.
- `LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)` (i.e., the single-argument method) has moved to the `Inference` module. It creates the `ADgradient` using the `adtype` information stored in the `context` field of `ℓ`.
- The `getADbackend` function has been renamed to `getADType`. The interface is preserved, but packages that previously used `getADbackend` should be updated to use `getADType`.
- The `TuringTag` for ForwardDiff has also been removed; `DynamicPPLTag`, now defined in the `DynamicPPL` package, serves the same [purpose](https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/).
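
A minimal sketch of the resulting workflow (the `demo` model below is illustrative only, not part of the release):

```julia
using Turing
import DynamicPPL, LogDensityProblemsAD

@model function demo()
    x ~ Normal()
end

# The AD backend is resolved via the renamed `getADType` from the `context`
# stored inside `ℓ`; with the default context this falls back to ForwardDiff.
ℓ = DynamicPPL.LogDensityFunction(demo())
∇ℓ = LogDensityProblemsAD.ADgradient(ℓ)
```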

# Release 0.30.0

- [`ADTypes.jl`](https://github.com/SciML/ADTypes.jl) replaced Turing's global AD backend. Users should now specify the desired `ADType` directly in sampler constructors, e.g., `HMC(0.1, 10; adtype=AutoForwardDiff(; chunksize))` or `HMC(0.1, 10; adtype=AutoReverseDiff(false))` (where `false` means the tape is not compiled).
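
For example (a sketch; `gdemo` is a stand-in model for illustration):

```julia
using Turing

@model function gdemo(x)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x .~ Normal(m, sqrt(s²))
end

# The AD backend is now a per-sampler setting rather than a global one:
chain = sample(gdemo([1.5, 2.0]), HMC(0.1, 10; adtype=AutoForwardDiff()), 1000)
```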
6 changes: 3 additions & 3 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.3"
version = "0.30.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -46,7 +46,7 @@ TuringOptimExt = "Optim"

[compat]
ADTypes = "0.2"
AbstractMCMC = "5"
AbstractMCMC = "5.2"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
@@ -58,7 +58,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.24.2"
DynamicPPL = "0.24.7"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
11 changes: 3 additions & 8 deletions src/Turing.jl
@@ -20,6 +20,7 @@ import Random

const PROGRESS = Ref(true)

# TODO: remove `PROGRESS` and this function in favour of `AbstractMCMC.PROGRESS`
"""
setprogress!(progress::Bool)
@@ -28,18 +29,12 @@ Enable progress logging in Turing if `progress` is `true`, and disable it otherwise.
function setprogress!(progress::Bool)
    @info "[Turing]: progress logging is $(progress ? "enabled" : "disabled") globally"
    PROGRESS[] = progress
    AbstractMCMC.setprogress!(progress; silent=true)
    # TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3
    AdvancedVI.turnprogress(progress)
    return progress
end
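
# Example (illustrative REPL interaction):
#
#     julia> Turing.setprogress!(false)
#     [ Info: [Turing]: progress logging is disabled globally
#     false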

# Standard tag: Improves stacktraces
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct TuringTag end

# Allow Turing tag in gradient etc. calls of the log density function
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:LogDensityFunction}, ::AbstractArray{V}) where {V} = true

# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")
3 changes: 0 additions & 3 deletions src/essential/Essential.jl
@@ -14,11 +14,8 @@ using StatsFuns: logsumexp, softmax
using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote

import AdvancedPS
import LogDensityProblems
import LogDensityProblemsAD

include("container.jl")
include("ad.jl")

export @model,
@varname,
45 changes: 0 additions & 45 deletions src/essential/ad.jl

This file was deleted.

63 changes: 31 additions & 32 deletions src/mcmc/Inference.jl
@@ -29,7 +29,6 @@ import AdvancedHMC; const AHMC = AdvancedHMC
import AdvancedMH; const AMH = AdvancedMH
import AdvancedPS
import BangBang
import ..Essential: getADbackend
import EllipticalSliceSampling
import LogDensityProblems
import LogDensityProblemsAD
@@ -79,7 +78,6 @@ abstract type ParticleInference <: InferenceAlgorithm end
abstract type Hamiltonian <: InferenceAlgorithm end
abstract type StaticHamiltonian <: Hamiltonian end
abstract type AdaptiveHamiltonian <: Hamiltonian end
getADbackend(alg::Hamiltonian) = alg.adtype

"""
ExternalSampler{S<:AbstractSampler}
@@ -99,6 +97,20 @@ Wrap a sampler so it can be used as an inference algorithm.
"""
externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler)

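# Resolve the AD type from a sampler, or from a (possibly nested) context by
# walking down to its leaf; `AutoForwardDiff(; chunksize=0)` is the fallback
# for `SampleFromPrior` and for leaf contexts without a sampler.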
getADType(spl::Sampler) = getADType(spl.alg)
getADType(::SampleFromPrior) = AutoForwardDiff(; chunksize=0)

getADType(ctx::DynamicPPL.SamplingContext) = getADType(ctx.sampler)
getADType(ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.NodeTrait(ctx), ctx)
getADType(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = AutoForwardDiff(; chunksize=0)
getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.childcontext(ctx))

getADType(alg::Hamiltonian) = alg.adtype

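# Single-argument `ADgradient`: build the gradient wrapper from the `adtype`
# stored in the `context` field of `ℓ` (see the 0.30.5 entry in HISTORY.md).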
function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
    return LogDensityProblemsAD.ADgradient(getADType(ℓ.context), ℓ)
end

function LogDensityProblems.logdensity(
f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
x::NamedTuple
@@ -116,6 +128,23 @@ DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.l
# Algorithm for sampling from the prior
struct Prior <: InferenceAlgorithm end

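# A single `step` for `Prior` evaluates the model under `SampleFromPrior` and
# `PriorContext`, returning one fresh draw from the prior as the transition.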
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:Prior},
    state=nothing;
    kwargs...,
)
    vi = last(DynamicPPL.evaluate!!(
        model,
        VarInfo(),
        SamplingContext(
            rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()
        )
    ))
    return vi, nothing
end

"""
mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)
@@ -231,36 +260,6 @@ function AbstractMCMC.sample(
chain_type=chain_type, progress=progress, kwargs...)
end

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::Prior,
    ensemble::AbstractMCMC.AbstractMCMCEnsemble,
    N::Integer,
    n_chains::Integer;
    chain_type=DynamicPPL.default_chain_type(alg),
    progress=PROGRESS[],
    kwargs...
)
    return AbstractMCMC.sample(rng, model, SampleFromPrior(), ensemble, N, n_chains;
        chain_type, progress, kwargs...)
end

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::Prior,
    N::Integer;
    chain_type=DynamicPPL.default_chain_type(alg),
    resume_from=nothing,
    initial_state=DynamicPPL.loadstate(resume_from),
    progress=PROGRESS[],
    kwargs...
)
    return AbstractMCMC.mcmcsample(rng, model, SampleFromPrior(), N;
        chain_type, initial_state, progress, kwargs...)
end

##########################
# Chain making utilities #
##########################
2 changes: 1 addition & 1 deletion src/mcmc/emcee.jl
@@ -58,7 +58,7 @@ function AbstractMCMC.step(
vis[1],
map(vis) do vi
vi = DynamicPPL.link!!(vi, spl, model)
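        # `AMH.Transition` now carries an `accepted` flag (AdvancedMH 0.8, per the
        # compat bump above); `false` marks this initial transition as not yet accepted.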
AMH.Transition(vi[spl], getlogp(vi))
AMH.Transition(vi[spl], getlogp(vi), false)
end
)

4 changes: 2 additions & 2 deletions src/mcmc/mh.jl
@@ -386,7 +386,7 @@ function propose!!(

# Create a sampler and the previous transition.
mh_sampler = AMH.MetropolisHastings(dt)
prev_trans = AMH.Transition(vt, getlogp(vi))
prev_trans = AMH.Transition(vt, getlogp(vi), false)

# Make a new transition.
densitymodel = AMH.DensityModel(
@@ -421,7 +421,7 @@ function propose!!(

# Create a sampler and the previous transition.
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
prev_trans = AMH.Transition(vals, getlogp(vi))
prev_trans = AMH.Transition(vals, getlogp(vi), false)

# Make a new transition.
densitymodel = AMH.DensityModel(
14 changes: 7 additions & 7 deletions src/optimisation/Optimisation.jl
@@ -283,17 +283,17 @@ function optim_function(
model::Model,
estimator::Union{MLE, MAP};
constrained::Bool=true,
autoad::Union{Nothing, AbstractADType}=NoAD(),
adtype::Union{Nothing, AbstractADType}=NoAD(),
)
if autoad === nothing
Base.depwarn("the use of `autoad=nothing` is deprecated, please use `autoad=SciMLBase.NoAD()`", :optim_function)
if adtype === nothing
Base.depwarn("the use of `adtype=nothing` is deprecated, please use `adtype=SciMLBase.NoAD()`", :optim_function)
end

obj, init, t = optim_objective(model, estimator; constrained=constrained)

l(x, _) = obj(x)
f = if autoad isa AbstractADType && autoad !== NoAD()
OptimizationFunction(l, autoad)
f = if adtype isa AbstractADType && adtype !== NoAD()
OptimizationFunction(l, adtype)
else
OptimizationFunction(
l;
@@ -310,10 +310,10 @@ function optim_problem(
estimator::Union{MAP, MLE};
constrained::Bool=true,
init_theta=nothing,
autoad::Union{Nothing, AbstractADType}=NoAD(),
adtype::Union{Nothing, AbstractADType}=NoAD(),
kwargs...,
)
f, init, transform = optim_function(model, estimator; constrained=constrained, autoad=autoad)
f, init, transform = optim_function(model, estimator; constrained=constrained, adtype=adtype)

u0 = init_theta === nothing ? init() : init(init_theta)
prob = OptimizationProblem(f, u0; kwargs...)
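
# Example (a sketch; note the keyword is now `adtype`, not `autoad`):
#
#     f, init, transform = optim_function(model, MLE(); adtype=AutoForwardDiff())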
66 changes: 22 additions & 44 deletions src/variational/advi.jl
@@ -1,45 +1,17 @@
# TODO(torfjelde): Find a better solution.
struct Vec{N,B} <: Bijectors.Bijector
    b::B
    size::NTuple{N, Int}
end

Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size)

Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz)
Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n)

function Bijectors.with_logabsdet_jacobian(f::Vec, x)
    return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x)
end

function Bijectors.transform(f::Vec, x::AbstractVector)
    # Reshape into shape compatible with wrapped bijector and then `vec` again.
    return vec(f.b(reshape(x, f.size)))
end

function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N
    # Reshape into shape compatible with original (forward) bijector and then `vec` again.
    return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size)))))
end

function Bijectors.transform(f::Vec, x::AbstractMatrix)
    # At the moment we do batching for higher-than-1-dim spaces by simply using
    # lists of inputs rather than `AbstractArray` with `N + 1` dimension.
    cols = Iterators.Stateful(eachcol(x))
    # Make `init` a matrix to ensure type-stability
    init = reshape(f(first(cols)), :, 1)
    return mapreduce(f, hcat, cols; init = init)
end

function Bijectors.logabsdetjac(f::Vec, x::AbstractVector)
    return Bijectors.logabsdetjac(f.b, reshape(x, f.size))
end
function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix)
    return map(eachcol(x)) do x_
        Bijectors.logabsdetjac(f, x_)
    end
end

# TODO: Move to Bijectors.jl if we find further use for this.
"""
    wrap_in_vec_reshape(f, in_size)

Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
a vector of length `prod(Bijectors.output_size(f, in_size))`.
"""
function wrap_in_vec_reshape(f, in_size)
    vec_in_length = prod(in_size)
    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
    out_size = Bijectors.output_size(f, in_size)
    vec_out_length = prod(out_size)
    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
    return reshape_outer ∘ f ∘ reshape_inner
end
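
# Example (a sketch; assumes a bijector over 2×2 matrices, e.g. from `LKJ`):
#
#     b = Bijectors.bijector(LKJ(2, 1.0))      # acts on a 2×2 matrix
#     b_vec = wrap_in_vec_reshape(b, (2, 2))   # acts on a length-4 vector instead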


@@ -83,7 +55,7 @@ function Bijectors.bijector(
if d isa Distributions.UnivariateDistribution
b
else
Vec(b, size(d))
wrap_in_vec_reshape(b, size(d))
end
end

@@ -106,7 +78,10 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model)
function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
# Setup.
varinfo = DynamicPPL.VarInfo(model)
num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
# Use linked `varinfo` to determine the correct number of parameters.
# TODO: Replace with `length` once this is implemented for `VarInfo`.
varinfo_linked = DynamicPPL.link(varinfo, model)
num_params = length(varinfo_linked[:])

# initial params
μ = randn(rng, num_params)
@@ -134,7 +109,10 @@ function AdvancedVI.update(
td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
θ::AbstractArray,
)
μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
# `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
# so we need to use the length of the underlying distribution `td.dist` here.
# TODO: Check if we can get away with `view` instead of `getindex` for all AD backends.
μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
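    # `StatsFuns.softplus`, i.e. log(1 + exp(ω)), keeps the scale strictly positive.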
return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
end

11 changes: 0 additions & 11 deletions test/essential/ad.jl
@@ -165,17 +165,6 @@

end

@testset "tag" begin
    for chunksize in (0, 1, 10)
        ad = Turing.AutoForwardDiff(; chunksize=chunksize)
        @test ad === Turing.AutoForwardDiff(; chunksize=chunksize)
        @test Turing.Essential.standardtag(ad)
        for standardtag in (false, 0, 1)
            @test !Turing.Essential.standardtag(Turing.AutoForwardDiff(; chunksize=chunksize, tag=standardtag))
        end
    end
end

@testset "ReverseDiff compiled without linking" begin
f = DynamicPPL.LogDensityFunction(gdemo_default)
θ = DynamicPPL.getparams(f)