Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump DynamicPPL to v0.25 #2197

Merged
merged 17 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.9"
version = "0.31.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Expand All @@ -29,7 +30,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
Expand All @@ -47,18 +47,19 @@ TuringOptimExt = "Optim"
[compat]
ADTypes = "0.2"
AbstractMCMC = "5.2"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
AdvancedVI = "0.2"
BangBang = "0.3"
BangBang = "0.4"
Bijectors = "0.13.6"
DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.24.10"
DynamicPPL = "0.25.1"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand All @@ -70,7 +71,6 @@ Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.37.1, 2"
Setfield = "0.8, 1"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
Statistics = "1.6"
StatsAPI = "1.6"
Expand Down
18 changes: 9 additions & 9 deletions ext/TuringOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ module TuringOptimExt

if isdefined(Base, :get_extension)
import Turing
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import Optim
else
import ..Turing
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import ..Optim
end

Expand Down Expand Up @@ -80,7 +80,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff
# Hessian is computed with respect to the untransformed parameters.
linked = DynamicPPL.istrans(m.f.varinfo)
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
end

# Calculate the Hessian, which is the information matrix because the negative of the log likelihood was optimized
Expand All @@ -89,7 +89,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff

# Link it back if we invlinked it.
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
end

return NamedArrays.NamedArray(info, (varnames, varnames))
Expand Down Expand Up @@ -227,8 +227,8 @@ function _optimize(
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)

# Optimize!
Expand All @@ -241,10 +241,10 @@ function _optimize(

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
Setfield.@set! f.varinfo = DynamicPPL.invlink(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
f = Accessors.@set f.varinfo = DynamicPPL.invlink(f.varinfo, model)
vals = DynamicPPL.getparams(f)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
Expand Down
4 changes: 3 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ using DynamicPPL: DynamicPPL, LogDensityFunction
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import NamedArrays
import Setfield
import Accessors
import StatsAPI
import StatsBase

using Accessors: Accessors

import Printf
import Random

Expand Down
2 changes: 1 addition & 1 deletion src/experimental/Experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Experimental
using Random: Random
using AbstractMCMC: AbstractMCMC
using DynamicPPL: DynamicPPL, VarName
using Setfield: Setfield
using Accessors: Accessors

using DocStringExtensions: TYPEDFIELDS
using Distributions
Expand Down
10 changes: 5 additions & 5 deletions src/experimental/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ Returns the preferred value type for a variable with the given `varinfo`.
preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict
preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple
function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo)
# We can only do this in the scenario where all the varnames are `Setfield.IdentityLens`.
# We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`.
namedtuple_compatible = all(varinfo.metadata) do md
eltype(md.vns) <: VarName{<:Any,Setfield.IdentityLens}
eltype(md.vns) <: VarName{<:Any,typeof(identity)}
end
return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict
end
Expand Down Expand Up @@ -321,8 +321,8 @@ function AbstractMCMC.step(
)

# Update the `states` and `varinfos`.
states = Setfield.setindex(states, new_state_local, index)
varinfos = Setfield.setindex(varinfos, new_varinfo_local, index)
states = Accessors.setindex(states, new_state_local, index)
varinfos = Accessors.setindex(varinfos, new_varinfo_local, index)
end

# Combine the resulting varinfo objects.
Expand All @@ -349,7 +349,7 @@ function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler
# NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide
# a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact
# same `selector` as before but now with `rerun` set to `true` if needed.
return Setfield.@set sampler.selector.rerun = true
return Accessors.@set sampler.selector.rerun = true
end

# Interface we need a sampler to implement to work as a component in a Gibbs sampler.
Expand Down
4 changes: 2 additions & 2 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ using DynamicPPL
using AbstractMCMC: AbstractModel, AbstractSampler
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DataStructures: OrderedSet
using Setfield: Setfield
using Accessors: Accessors

import ADTypes
import AbstractMCMC
import AdvancedHMC; const AHMC = AdvancedHMC
import AdvancedMH; const AMH = AdvancedMH
import AdvancedPS
import BangBang
import Accessors
import EllipticalSliceSampling
import LogDensityProblems
import LogDensityProblemsAD
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.pa
getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))

setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Setfield.@set f.varinfo = varinfo
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo) = setvarinfo(parent(f), varinfo)

# TODO: Do we also support `resume`, etc?
Expand Down
4 changes: 2 additions & 2 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Bijectors
using Random
using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType, NoAD

using Setfield
using Accessors: Accessors
using DynamicPPL
using DynamicPPL: Model, AbstractContext, VarInfo, VarName,
_getindex, getsym, getfield, setorder!,
Expand Down Expand Up @@ -150,7 +150,7 @@ function transform!!(f::OptimLogDensity)
linked = DynamicPPL.istrans(f.varinfo)

## transform into constrained or unconstrained space depending on current state of vi
@set! f.varinfo = if !linked
f = Accessors.@set f.varinfo = if !linked
DynamicPPL.link!!(f.varinfo, f.model)
else
DynamicPPL.invlink!!(f.varinfo, f.model)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.24"
DynamicPPL = "0.25.1"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"
Expand Down
20 changes: 12 additions & 8 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,18 @@
end
end

@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
if VERSION ≥ v"1.8"
@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) ≈ 0.2
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) ≈ 0.2
end
end
34 changes: 19 additions & 15 deletions test/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,28 @@
# @test v1 < v2
end

@turing_testset "vector of multivariate distributions" begin
@model function test(k)
T = Vector{Vector{Float64}}(undef, k)
for i in 1:k
T[i] ~ Dirichlet(5, 1.0)
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
if VERSION ≥ v"1.8"
@turing_testset "vector of multivariate distributions" begin
@model function test(k)
T = Vector{Vector{Float64}}(undef, k)
for i in 1:k
T[i] ~ Dirichlet(5, 1.0)
end
end
end

Random.seed!(100)
chain = sample(test(1), MH(), 5_000)
for i in 1:5
@test mean(chain, "T[1][$i]") ≈ 0.2 atol=0.01
end
Random.seed!(100)
chain = sample(test(1), MH(), 5_000)
for i in 1:5
@test mean(chain, "T[1][$i]") ≈ 0.2 atol = 0.01
end

Random.seed!(100)
chain = sample(test(10), MH(), 5_000)
for j in 1:10, i in 1:5
@test mean(chain, "T[$j][$i]") ≈ 0.2 atol=0.01
Random.seed!(100)
chain = sample(test(10), MH(), 5_000)
for j in 1:10, i in 1:5
@test mean(chain, "T[$j][$i]") ≈ 0.2 atol = 0.01
end
end
end

Expand Down
Loading