Skip to content

Commit

Permalink
Try #259:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jun 9, 2021
2 parents ef6da43 + fa6c4d6 commit 7b948c6
Show file tree
Hide file tree
Showing 13 changed files with 650 additions and 195 deletions.
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export AbstractVarInfo,
SampleFromPrior,
SampleFromUniform,
# Contexts
SamplingContext,
DefaultContext,
LikelihoodContext,
PriorContext,
Expand Down
37 changes: 17 additions & 20 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,7 @@ function generate_tilde(left, right)
if !(left isa Symbol || left isa Expr)
return quote
$(DynamicPPL.tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
Expand All @@ -304,9 +300,7 @@ function generate_tilde(left, right)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
__rng__,
__context__,
__sampler__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
Expand All @@ -316,7 +310,6 @@ function generate_tilde(left, right)
else
$(DynamicPPL.tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
Expand All @@ -337,11 +330,7 @@ function generate_dot_tilde(left, right)
if !(left isa Symbol || left isa Expr)
return quote
$(DynamicPPL.dot_tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
Expand All @@ -355,9 +344,7 @@ function generate_dot_tilde(left, right)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__rng__,
__context__,
__sampler__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
Expand All @@ -367,7 +354,6 @@ function generate_dot_tilde(left, right)
else
$(DynamicPPL.dot_tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
Expand Down Expand Up @@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__rng__::$(Random.AbstractRNG)),
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__sampler__::$(DynamicPPL.AbstractSampler)),
:(__context__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
Expand All @@ -411,7 +395,9 @@ function build_output(modelinfo, linenumbernode)
evaluatordef[:kwargs] = []

# Replace the user-provided function body with the version created by DynamicPPL.
evaluatordef[:body] = modelinfo[:body]
evaluatordef[:body] = quote
$(modelinfo[:body])
end

## Build the model function.

Expand Down Expand Up @@ -449,8 +435,12 @@ end

"""
matchingvalue(sampler, vi, value)
matchingvalue(context::AbstractContext, vi, value)
Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
Convert the `value` to the correct type for the `sampler` and the `vi` object.
For a `context` that is _not_ a `SamplingContext`, we fall back to
`matchingvalue(SampleFromPrior(), vi, value)`.
"""
function matchingvalue(sampler, vi, value)
T = typeof(value)
Expand All @@ -467,6 +457,13 @@ function matchingvalue(sampler, vi, value)
end
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)

function matchingvalue(context::AbstractContext, vi, value)
return matchingvalue(SampleFromPrior(), vi, value)
end
function matchingvalue(context::SamplingContext, vi, value)
return matchingvalue(context.sampler, vi, value)
end

"""
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
Expand Down
Loading

0 comments on commit 7b948c6

Please sign in to comment.