diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9eb4d9675..914c0e12b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -75,6 +75,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, diff --git a/src/compiler.jl b/src/compiler.jl index 2e368d32b..352d46418 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -286,11 +286,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -304,9 +300,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -316,7 +310,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -337,11 +330,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -355,9 +344,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -367,7 +354,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -411,7 +395,9 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + $(modelinfo[:body]) + end ## Build the model function. @@ -449,8 +435,12 @@ end """ matchingvalue(sampler, vi, value) + matchingvalue(context::AbstractContext, vi, value) + +Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. -Convert the `value` to the correct type for the `sampler` and the `vi` object. +For a `context` that is _not_ a `SamplingContext`, we fall back to +`matchingvalue(SampleFromPrior(), vi, value)`. 
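As a usage sketch of the new calling convention introduced above (the `demo` model and its data are made up for illustration): the RNG and sampler are no longer separate `__rng__`/`__sampler__` arguments of the generated evaluator but travel inside a `SamplingContext` that wraps the evaluation context.

```julia
using DynamicPPL, Distributions, Random

@model function demo(x)
    m ~ Normal()
    return x ~ Normal(m, 1)
end

model = demo(1.5)
vi = VarInfo()

# Sampling-specific state (RNG + sampler) lives in the context; the wrapped
# `DefaultContext` decides how the log density is accumulated.
ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext())
model(vi, ctx)
getlogp(vi)  # log joint of the sampled `m` and the observation `x = 1.5`
```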
""" function matchingvalue(sampler, vi, value) T = typeof(value) @@ -467,6 +457,13 @@ function matchingvalue(sampler, vi, value) end matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(SampleFromPrior(), vi, value) +end +function matchingvalue(context::SamplingContext, vi, value) + return matchingvalue(context.sampler, vi, value) +end + """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 60df298b5..77cbc0fb2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,86 +18,194 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +``` +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = 
vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +end + +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume(context, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. + +Falls back to `tilde_observe(context.context, context.sampler, right, left, vname, vinds, vi)`. +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + return tilde_observe( + context.context, context.sampler, right, left, vname, vinds, vi + ) +end + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. + +Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. 
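A short sketch of how the `PriorContext{<:NamedTuple}` methods above are typically exercised (the `demo` model and the pinned value are hypothetical): values from `context.vars` are written into `vi` before the prior log density is accumulated, and observations contribute nothing under `PriorContext`.

```julia
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal()
    return x ~ Normal(m, 1)
end

model = demo(2.0)
vi = VarInfo(model)

# Pin `m` to 0.0 via the context; `tilde_assume` stores it in `vi` and only the
# prior term is accumulated (observe returns 0 under `PriorContext`).
model(vi, PriorContext((m=0.0,)))
getlogp(vi)  # ≈ logpdf(Normal(), 0.0)
```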
+""" +function tilde_observe(context::SamplingContext, right, left, vi) + return tilde_observe(context.context, context.sampler, right, left, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(context, right, left, vname, vinds, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `tilde(context, right, left, vi)`. 
""" -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -110,14 +218,28 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + vi, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) + r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) @@ -125,79 +247,187 @@ function assume( r = vi[vn] end else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) + r = init(rng, dist, sampler) + push!(vi, vn, r, dist, sampler) settrans!(vi, false, vn) end + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. 
+ +Falls back to +```julia +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) +``` +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + return dot_tilde_assume( + context.rng, context.context, context.sampler, right, left, vn, inds, vi + ) +end + +# `DefaultContext` +function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) + return dot_assume(right, left, vns, vi) +end + +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, 
_left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + return dot_assume(right, left, vn, vi) +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + return dot_assume(rng, sampler, right, vn, left, vi) +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function dot_tilde_assume( + rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) +end + +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) + return dot_tilde_assume( + rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + ) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi +) + @assert length(dist) == size(var, 1) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + lp = sum(zip(vns, eachcol(r))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return r, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -211,6 +441,24 @@ function dot_assume( lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) return r, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi, +) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
+ r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -319,84 +567,111 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.context, context.sampler, right, left, vi) +end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 +dot_tilde_observe(::PriorContext, right, left, vi) = 0 +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) + return dot_observe(right, left, vi) end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `MiniBatchContext` +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(context, right, left, vn, inds, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. 
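A hedged sketch of the `.~` code path served by the sampler-free `dot_assume` methods above (the model is invented but mirrors the test models added later in this diff): `m` starts as an `undef` vector, so the values are taken from `vi` rather than from `m` itself.

```julia
using DynamicPPL, Distributions

@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
    m = TV(undef, length(x))
    m .~ Normal()
    return x ~ MvNormal(m, 0.5 * ones(length(x)))
end

model = demo([1.0, -1.0])
vi = VarInfo(model)  # `m[1]`, `m[2]` are sampled and stored here

# Sampler-free evaluation: `dot_assume` reads `m` from `vi` and returns the
# prior log density; the `MvNormal` observation adds the log-likelihood.
model(vi, DefaultContext())
getlogp(vi)
```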
""" -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi, ) + return dot_observe(dist, value, vi) +end +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi, ) + return dot_observe(dists, value, vi) +end +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi, ) + return dot_observe(dists, value, vi) +end +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end diff --git a/src/contexts.jl b/src/contexts.jl index 2c23531c6..8093c88f3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,17 @@ +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + """ struct DefaultContext <: AbstractContext end @@ -35,7 +49,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -46,31 +60,42 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. 
+ +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..6fca717c6 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,80 +1,102 @@ # Context version struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A - ctx::Ctx + context::Ctx end function PointwiseLikelihoodContext( - likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), + context::AbstractContext=LikelihoodContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[string(vn)] = logp + return context.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::String, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, 
vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!(context.context, right, left, vi) +end +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - return value + + # Track loglikelihood value. + push!(context, vn, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) - # This is slightly unfortunate since it is not completely generic... - # Ideally we would call `tilde_observe` recursively but then we don't get the - # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe(context.context, right, left, vi) +end +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - # track loglikelihood value - push!(ctx, vname, logp) + # Track loglikelihood value. 
+ push!(context, vn, logp) return left end @@ -150,30 +172,29 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) + context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, context) end niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in ctx.loglikelihoods + (varname, logliks) in context.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) - return ctx.loglikelihoods + context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}()) + model(varinfo, context) + return context.loglikelihoods end diff --git a/src/model.jl b/src/model.jl index 7189b590e..2d74949c1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(varinfo, SamplingContext(rng, sampler, context)) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. 
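As a usage sketch for the reworked `pointwise_loglikelihoods` above (the model and data are invented): with a plain `VarInfo` argument it now returns a `Dict{VarName,Vector{Float64}}` mapping each observation to its log-likelihood contribution.

```julia
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal()
    for i in eachindex(x)
        x[i] ~ Normal(m, 1)
    end
end

model = demo([0.5, -0.5])
vi = VarInfo(model)

# One entry per observe statement, keyed by `VarName` (`x[1]`, `x[2]`).
lls = DynamicPPL.pointwise_loglikelihoods(model, vi)
```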
@@ -134,24 +140,27 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return quote + sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + model.f(model, varinfo, context, $(unwrap_args...)) + end end """ @@ -183,7 +192,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +204,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +216,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 92584ae8b..1d574e286 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,22 +1,14 @@ macro submodel(expr) return quote - _evaluate( - $(esc(:__rng__)), - $(esc(expr)), - $(esc(:__varinfo__)), - $(esc(:__sampler__)), - $(esc(:__context__)), - ) + _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) end end macro submodel(prefix, expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,6 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# HACK: Type-piracy. Is this really the way to go? 
+AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..d219f91ea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -172,10 +172,10 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end @@ -184,18 +184,17 @@ end @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model - @test sampler_ === SampleFromPrior() - @test context_ === DefaultContext() + @test context_ isa SamplingContext @test rng_ isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end false diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl new file mode 100644 index 000000000..74fb88d70 --- /dev/null +++ b/test/loglikelihoods.jl @@ -0,0 +1,123 @@ +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x=10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. 
+@model function gdemo5(x=10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + return x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + return x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const gdemo_models = ( + gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() +) + +@testset "loglikelihoods.jl" begin + for m in gdemo_models + vi = VarInfo(m) + + vns = vi.metadata.m.vns + if length(vns) == 1 && length(vi[vns[1]]) == 1 + # Only have one latent variable. + DynamicPPL.setval!(vi, [1.0], ["m"]) + else + DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + end + + lls = pointwise_loglikelihoods(m, vi) + + if isempty(lls) + # One of the models with literal observations, so we just skip. + continue + end + + loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + # Only have one observation, so we need to double it + # for comparison with other models. + 2 * sum(lls[first(keys(lls))]) + else + sum(sum, values(lls)) + end + + @test loglikelihood ≈ -324.45158270528947 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3d5d55c..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,8 @@ include("test_util.jl") include("threadsafe.jl") include("serialization.jl") + + include("loglikelihoods.jl") end @testset "compat" begin diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 746d6a5f8..83c53ccd6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,18 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @model function wothreads(x) @@ -96,14 +100,18 @@ # Ensure that we use `VarInfo`. 
DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) end end
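Finally, a sketch (hypothetical model) of how the two user-facing call styles relate after this change: the positional RNG/sampler form is wrapped into a `SamplingContext` internally, so both calls below perform the same kind of evaluation.

```julia
using DynamicPPL, Distributions, Random

@model function demo(x)
    m ~ Normal()
    return x ~ Normal(m, 1)
end

model = demo(0.25)

# Old-style positional arguments still work; `Model` wraps them into
# `SamplingContext(rng, sampler, context)` before evaluating.
vi1 = VarInfo()
model(Random.GLOBAL_RNG, vi1, SampleFromPrior(), DefaultContext())

# Equivalent explicit form.
vi2 = VarInfo()
model(vi2, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()))
```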