Commit: Work towards fixing minipyro

fritzo committed Mar 26, 2019
1 parent 15b0c73, commit 72b751b

Showing 1 changed file with 10 additions and 4 deletions.

funsor/minipyro.py: 10 additions & 4 deletions

@@ -55,7 +55,7 @@ def __enter__(self):
 
     # trace illustrates why we need postprocess in addition to process:
     # We only want to record a value after all other effects have been applied
-    @dispatch(Sample)
+    @dispatch((Sample, Param))
     def postprocess(self, msg):
         assert msg["name"] is not None and \
             msg["name"] not in self.trace, \
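
Note: the comment in this hunk explains the design constraint: a value may still be changed by outer effects (e.g. a replay-style handler) after it is first produced, so trace must record in postprocess, after every process hook has run. Widening the dispatch to (Sample, Param) simply applies the same recording to parameter sites. A standalone toy sketch of that ordering (these are not funsor.minipyro's actual classes):

    class Replay:
        def process(self, msg):
            msg["value"] = 42.0              # an outer effect overrides the value

    class Tracer:
        def __init__(self):
            self.trace = {}
        def postprocess(self, msg):         # record only after all effects ran
            self.trace[msg["name"]] = msg["value"]

    msg = {"name": "z", "value": None}
    replay, tracer = Replay(), Tracer()
    replay.process(msg)                     # effects run first...
    if msg["value"] is None:
        msg["value"] = 0.0                  # ...then a default would be computed...
    tracer.postprocess(msg)                 # ...and the tracer records last
    print(tracer.trace)                     # {'z': 42.0}
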
@@ -129,7 +129,7 @@ def __iter__(self):
         return range(self.size)
 
 
-def sample(fn, obs=None, name=None):
+def sample(name, fn, obs=None):
     """
     This is an effectful version of ``Distribution.sample(...)``. When any
     effect handlers are active, it constructs an initial message and calls
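
Note: moving name to the first positional argument matches Pyro's pyro.sample(name, fn, obs=None) convention, so call sites read identically in both libraries. A schematic call site under the new signature (the torch-based fn is illustrative, and this assumes, as in pyro.contrib.minipyro, that sample just calls fn() when no handlers are active):

    import torch
    from funsor.minipyro import sample

    d = torch.distributions.Normal(0., 1.)
    z = sample("z", d.sample)                          # draw at site "z"
    x = sample("x", d.sample, obs=torch.tensor(0.5))   # observed site "x"
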
@@ -155,7 +155,7 @@ def sample(fn, obs=None, name=None):
     return msg["value"]
 
 
-def param(init_value=None, name=None):
+def param(name, init_value=None):
     """
     This is an effectful version of ``PARAM_STORE.setdefault``. When any effect
     handlers are active, it constructs an initial message and calls
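
Note: the same name-first reordering, mirroring pyro.param(name, init_value). Per the setdefault semantics described in the docstring, repeating a name should return the stored value and ignore the new initializer. A schematic sketch (store behavior assumed from the docstring, not shown in this diff):

    import torch
    from funsor.minipyro import param

    loc = param("loc", torch.zeros(1))   # first call: stores and returns the init
    loc = param("loc", torch.ones(1))    # same name: stored value returned,
                                         # new init ignored (setdefault-style)
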
@@ -176,8 +176,10 @@ def fn(init_value):
 
     # Otherwise, we initialize a message...
     initial_msg = Param(**{
+        "name": name,
         "fn": fn,
         "args": (init_value,),
+        "kwargs": {},
         "value": None,
     })
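
Note: a guess at why the "kwargs" slot was added: by minipyro convention (see apply_stack in pyro.contrib.minipyro, which this diff does not show), a message whose value is still None gets one by applying its fn to its stored arguments, so the message must carry both positional and keyword arguments. A standalone sketch of that default step, with msg as a plain dict:

    msg = {"fn": lambda x, scale=1.0: x * scale,
           "args": (2.0,), "kwargs": {"scale": 3.0}, "value": None}
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
    print(msg["value"])  # 6.0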

@@ -399,6 +401,10 @@ def elbo(model, guide, *args, **kwargs):
     return -elbo  # negate, for use as loss
 
 
+def Trace_ELBO(*args, **kwargs):
+    return elbo
+
+
 class SVI(object):
     """
     This is a unified interface for stochastic variational inference in Pyro.
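
Note: the shim lets code written against Pyro's API, where the loss is constructed as an object via Trace_ELBO(), run unchanged here; any constructor arguments are simply discarded. Its observable behavior, straight from the diff:

    from funsor.minipyro import Trace_ELBO, elbo

    assert Trace_ELBO() is elbo                   # Pyro-style construction...
    assert Trace_ELBO(num_particles=10) is elbo   # ...options are ignored for now

So, assuming SVI's constructor matches pyro.contrib.minipyro's (model, guide, optim, loss), SVI(model, guide, optim, Trace_ELBO()) behaves exactly like passing elbo directly.
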
@@ -419,7 +425,7 @@ def step(self, *args, **kwargs):
         # further tracing occurs inside of `loss`.
         with trace() as param_capture:
             # We use block here to allow tracing to record parameters only.
-            with block(hide_fn=lambda msg: msg["type"] == "sample"):
+            with block(hide_fn=lambda msg: isinstance(msg, Sample)):
                 loss = self.loss(self.model, self.guide, *args, **kwargs)
         # Differentiate the loss.
         loss.backward()
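
Note: this predicate change is the common thread of the commit: sites are now typed message objects (Sample, Param) dispatched by class, rather than dicts tagged with a "type" string. The same idiom can filter the other way, e.g. to hide parameters instead (a sketch assuming Param is importable like Sample):

    from funsor.minipyro import Param, block, trace

    with trace() as tr:
        with block(hide_fn=lambda msg: isinstance(msg, Param)):
            pass  # run a model here; only Sample sites reach the outer trace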
