Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds values_as_in_model #588

Merged
merged 13 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl
extract_priors
```

Safe extraction of realizations from a given [`AbstractVarInfo`](@ref) can be done using [`extract_realizations`](@ref).

```@docs
extract_realizations
```

```@docs
NamedDist
```
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export AbstractVarInfo,
getargnames,
generated_quantities,
extract_priors,
extract_realizations,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
186 changes: 186 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,189 @@ function fixed(context::FixedContext)
# precedence over decendants of `context`.
return merge(context.values, fixed(childcontext(context)))
end

"""
RealizationExtractorContext

A context that is used to extract realizations from a model.

This is particularly useful when working in unconstrained space, but one
wants to extract the realization of a model in a constrained space.

# Fields
$(TYPEDFIELDS)
"""
struct RealizationExtractorContext{T,C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
"child context"
context::C
end

RealizationExtractorContext(values) = RealizationExtractorContext(values, DefaultContext())
function RealizationExtractorContext(context::AbstractContext)
return RealizationExtractorContext(OrderedDict(), context)
end

NodeTrait(::RealizationExtractorContext) = IsParent()
childcontext(context::RealizationExtractorContext) = context.context
function setchildcontext(context::RealizationExtractorContext, child)
return RealizationExtractorContext(context.values, child)
end

function Base.push!(context::RealizationExtractorContext, vn::VarName, value)
return setindex!(context.values, copy(value), vn)
end

function broadcast_push!(context::RealizationExtractorContext, vns, values)
return push!.((context,), vns, values)
end

# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
function broadcast_push!(
context::RealizationExtractorContext, vns::AbstractVector, values::AbstractMatrix
)
for (vn, col) in zip(vns, eachcol(values))
push!(context, vn, col)
end
end

# `tilde_asssume`
function tilde_assume(context::RealizationExtractorContext, right, vn, vi)
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
# Save the value.
push!(context, vn, value)
# Save the value.
# Pass on.
return value, logp, vi
end
function tilde_assume(
rng::Random.AbstractRNG, context::RealizationExtractorContext, sampler, right, vn, vi
)
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
# Save the value.
push!(context, vn, value)
# Pass on.
return value, logp, vi
end

# `dot_tilde_assume`
function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)

# Save the value.
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::RealizationExtractorContext,
sampler,
right,
left,
vn,
vi,
)
value, logp, vi = dot_tilde_assume(
rng, childcontext(context), sampler, right, left, vn, vi
)
# Save the value.
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
end

"""
extract_realizations(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
extract_realizations(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])

Extract realizations from the `model` for a given `varinfo` through a evaluation of the model.

If no `varinfo` is provided, then this is effectively the same as
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).

More specifically, this method attempts to extract the realization _as seen in the model_.
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
space.

Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.

# Arguments
- `model::Model`: model to extract realizations from.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.

# Examples

## When `VarInfo` fails

The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.

```jldoctest
julia> using Distributions, StableRNGs

julia> rng = StableRNG(42);

julia> @model function model_changing_support()
x ~ Bernoulli(0.5)
y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
end;

julia> model = model_changing_support();

julia> # Construct initial type-stable `VarInfo`.
varinfo = VarInfo(rng, model);

julia> # Link it so it works in unconstrained space.
varinfo_linked = DynamicPPL.link(varinfo, model);

julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
# Flip `x` so we hit the other support of `y`.
θ = [!varinfo[@varname(x)], rand(rng)];

julia> # Update the `VarInfo` with the new values.
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);

julia> # Determine the expected support of `y`.
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
(0, 1)

julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);

julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
# used in the very first model evaluation, hence the support of `y`
# is not updated even though `x` has changed.
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
false

julia> # Approach 2: Extract realizations using `extract_realizations`.
# (✓) `extract_realizations` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ extract_realizations(model, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function extract_realizations(
model::Model,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
context = RealizationExtractorContext(context)
evaluate!!(model, varinfo, context)
return context.values
end
function extract_realizations(
rng::Random.AbstractRNG,
model::Model,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return extract_realizations(model, varinfo, SamplingContext(rng, context))
end
50 changes: 46 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
AbstractProbabilisticProgram
struct Model{
F,
argnames,
defaultnames,
missings,
Targs,
Tdefaults,
Ctx<:AbstractContext,
IsStatic<:Union{Val{false},Val{true}},
} <: AbstractProbabilisticProgram
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx
has_static_support::IsStatic

@doc """
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
Expand All @@ -49,9 +58,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
context::Ctx=DefaultContext(),
has_static_support::Union{Val{false},Val{true}}=Val{false}(),
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
f, args, defaults, context
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(isstatic)}(
f, args, defaults, context, has_static_support
)
end
end
Expand All @@ -78,6 +88,38 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k
return Model(f, args, NamedTuple(kwargs), context)
end

"""
has_static_support(model::Model)

Return `true` if `model` has static support.
"""
has_static_support(model::Model) = model.has_static_support isa Val{true}

"""
set_static_support(model::Model, isstatic::Val{true},Val{false})

Set `model` to have static support if `isstatic` is `true`, otherwise not.
"""
function set_static_support(model::Model, isstatic::Union{Val{true},Val{false}})
return Model{getmissings(model)}(
model.f, model.args, model.defaults, model.context, isstatic
)
end

"""
mark_as_static_support(model::Model)

Mark `model` as having static support.
"""
mark_as_static_support(model::Model) = set_static_support(model, Val{true}())

"""
mark_as_dynamic_support(model::Model)

Mark `model` as not having static support.
"""
mark_as_dynamic_support(model::Model) = set_static_support(model, Val{false}())

function contextualize(model::Model, context::AbstractContext)
return Model(model.f, model.args, model.defaults, context)
end
Expand Down
20 changes: 19 additions & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
]
@testset "$(model.f)" for model in models_to_test
vns = DynamicPPL.TestUtils.varnames(model)
example_values = DynamicPPL.TestUtils.rand(model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = filter(
is_typed_varinfo,
DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns),
Expand All @@ -375,4 +375,22 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end
end
end

@testset "extract_realizations" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
realizations = extract_realizations(model, varinfo)
# Ensure that all variables are found.
vns_found = collect(keys(realizations))
@test vns ∩ vns_found == vns ∪ vns_found
# Ensure that the values are the same.
for vn in vns
@test realizations[vn] == varinfo[vn]
end
end
end
end
end
Loading