From 72b751beea0d8bda61ad9fb6d1d09cf29cfa9ceb Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 26 Mar 2019 00:13:43 -0700 Subject: [PATCH] Work towards fixing minipyro --- funsor/minipyro.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/funsor/minipyro.py b/funsor/minipyro.py index 837264b51..9f570b3dd 100644 --- a/funsor/minipyro.py +++ b/funsor/minipyro.py @@ -55,7 +55,7 @@ def __enter__(self): # trace illustrates why we need postprocess in addition to process: # We only want to record a value after all other effects have been applied - @dispatch(Sample) + @dispatch((Sample, Param)) def postprocess(self, msg): assert msg["name"] is not None and \ msg["name"] not in self.trace, \ @@ -129,7 +129,7 @@ def __iter__(self): return range(self.size) -def sample(fn, obs=None, name=None): +def sample(name, fn, obs=None): """ This is an effectful version of ``Distribution.sample(...)``. When any effect handlers are active, it constructs an initial message and calls @@ -155,7 +155,7 @@ def sample(fn, obs=None, name=None): return msg["value"] -def param(init_value=None, name=None): +def param(name, init_value=None): """ This is an effectful version of ``PARAM_STORE.setdefault``. When any effect handlers are active, it constructs an initial message and calls @@ -176,8 +176,10 @@ def fn(init_value): # Otherwise, we initialize a message... initial_msg = Param(**{ + "name": name, "fn": fn, "args": (init_value,), + "kwargs": {}, "value": None, }) @@ -399,6 +401,10 @@ def elbo(model, guide, *args, **kwargs): return -elbo # negate, for use as loss +def Trace_ELBO(*args, **kwargs): + return elbo + + class SVI(object): """ This is a unified interface for stochastic variational inference in Pyro. @@ -419,7 +425,7 @@ def step(self, *args, **kwargs): # further tracing occurs inside of `loss`. with trace() as param_capture: # We use block here to allow tracing to record parameters only. - with block(hide_fn=lambda msg: msg["type"] == "sample"): + with block(hide_fn=lambda msg: isinstance(msg, Sample)): loss = self.loss(self.model, self.guide, *args, **kwargs) # Differentiate the loss. loss.backward()