First pass at fixing minipyro
fritzo committed Mar 27, 2019
1 parent 46113dd commit d13ce0a
Showing 1 changed file with 86 additions and 44 deletions.
130 changes: 86 additions & 44 deletions funsor/minipyro.py
@@ -17,6 +17,8 @@

import torch

import funsor

# Pyro keeps track of two kinds of global state:
# i) The effect handler stack, which enables non-standard interpretations of
# Pyro primitives like sample();
@@ -77,19 +79,33 @@ def get_trace(self, *args, **kwargs):
return self.trace


# A second example of an effect handler for setting the value at a sample site.
# This illustrates why effect handlers are a useful PPL implementation technique:
# We can compose trace and replay to replace values but preserve distributions,
# allowing us to compute the joint probability density of samples under a model.
# See the definition of elbo(...) below for an example of this pattern.
class replay(Messenger):
def __init__(self, fn, guide_trace):
self.guide_trace = guide_trace
super(replay, self).__init__(fn)
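
# An effect handler that lazily records the joint log density of a model (or guide)
# execution: sample sites take free funsor.Variables as values, and each site's log
# density is stored as a funsor factor together with the names of its enclosing plates.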
class log_joint(Messenger):
def __enter__(self):
super(log_joint, self).__enter__()
self.log_factors = OrderedDict()
self.plates = set()
return self

def process_message(self, msg):
if msg["name"] in self.guide_trace:
msg["value"] = self.guide_trace[msg["name"]]["value"]
if msg["type"] != "sample":
return None
msg["value"] = funsor.Variable(msg["name"], msg["fn"].inputs["value"])

def postprocess_message(self, msg):
if msg["type"] != "sample":
return None
assert msg["name"] not in self.log_factors, "all sites must have unique names"
log_prob = msg["fn"](value=msg["value"])
self.log_factors[msg["name"]] = log_prob
self.plates.update(msg["cond_indep_stack"].values()) # maps dim to name

def contract(self):
return funsor.sum_product.sum_product(
sum_op=funsor.ops.logaddexp,
prod_op=funsor.ops.add,
factors=list(self.log_factors.values()),
plates=frozenset(self.plates),
eliminate=frozenset(self.log_factors.keys()))
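
# A minimal standalone sketch of the contraction performed by contract() above, using
# two hand-made log factors (the variable names "x", "y" and the random data are
# purely illustrative and not part of this module):
#
#     f = funsor.torch.Tensor(torch.randn(3), OrderedDict(x=funsor.bint(3)))
#     g = funsor.torch.Tensor(torch.randn(3, 2),
#                             OrderedDict(x=funsor.bint(3), y=funsor.bint(2)))
#     log_z = funsor.sum_product.sum_product(
#         sum_op=funsor.ops.logaddexp, prod_op=funsor.ops.add,
#         factors=[f, g], plates=frozenset(), eliminate=frozenset(["x", "y"]))
#
# which computes log(sum over x, y of exp(f(x) + g(x, y))) as a scalar funsor.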


# block allows the selective application of effect handlers to different parts of a model.
@@ -107,19 +123,43 @@ def process_message(self, msg):

# This limited implementation of PlateMessenger only implements broadcasting.
class PlateMessenger(Messenger):
def __init__(self, fn, size, dim):
def __init__(self, fn, size, dim, name):
assert dim < 0
self.size = size
self.dim = dim
self.name = name
super(PlateMessenger, self).__init__(fn)

def process_message(self, msg):
if msg["type"] == "sample":
batch_shape = msg["fn"].batch_shape
if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size:
batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape)
batch_shape[self.dim] = self.size
msg["fn"] = msg["fn"].expand(torch.Size(batch_shape))
assert self.dim not in msg["cond_indep_stack"]
msg["cond_indep_stack"][self.dim] = self.name

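            # If a value was supplied (e.g. observed data passed as obs=...), wrap the
            # raw torch.Tensor as a funsor.torch.Tensor, naming each nontrivial batch
            # dimension after its enclosing plate (looked up in cond_indep_stack).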
if msg["value"] is not None:
value = msg["value"]
if not isinstance(value, funsor.Funsor):
assert isinstance(value, torch.Tensor)
output = msg["fn"].inputs["value"]
event_shape = output.shape
batch_shape = value.shape[:value.dim() - len(event_shape)]
inputs = OrderedDict()
data = value
for dim, size in enumerate(batch_shape):
if size == 1:
data = data.squeeze(dim)
else:
name = msg["cond_indep_stack"][dim - len(batch_shape)]
inputs[name] = funsor.bint(size)
value = funsor.torch.Tensor(data, inputs, output.dtype)
assert value.output == output
msg["value"] = value

# TODO expand function
# batch_shape = msg["fn"].batch_shape
# if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size:
# batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape)
# batch_shape[self.dim] = self.size
# msg["fn"] = msg["fn"].expand(torch.Size(batch_shape))

def __iter__(self):
        return iter(range(self.size))
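
# A sketch of the bookkeeping performed by PlateMessenger (names are illustrative,
# and some_funsor_dist stands for any funsor distribution with a "value" input): under
#
#     with plate("data", 10, dim=-1):
#         sample("x", some_funsor_dist, obs=torch.randn(10))
#
# the sample message gets cond_indep_stack == {-1: "data"}, and the observed value of
# shape (10,) is wrapped as
#
#     funsor.torch.Tensor(value, OrderedDict(data=funsor.bint(10)), output.dtype)
#
# i.e. a funsor with a named "data" input of size 10 rather than a positional batch dim.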
@@ -151,7 +191,8 @@ def sample(name, fn, obs=None):

# if there are no active Messengers, we just draw a sample and return it as expected:
if not PYRO_STACK:
return fn()
raise NotImplementedError('Funsor cannot sample')
# return fn()
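        # (Raw sampling is not supported in this funsor port: a sample site only gets a
        # value from an enclosing handler, e.g. log_joint binds it to a free
        # funsor.Variable.)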

# Otherwise, we initialize a message...
initial_msg = {
@@ -160,6 +201,7 @@
"fn": fn,
"args": (),
"value": obs,
"cond_indep_stack": {}, # maps dim to name
}

# ...and use apply_stack to send it to the Messengers
@@ -196,7 +238,7 @@ def fn(init_value):

# boilerplate to match the syntax of actual pyro.plate:
def plate(name, size, dim):
return PlateMessenger(fn=None, size=size, dim=dim)
return PlateMessenger(fn=None, size=size, dim=dim, name=name)


# This is a thin wrapper around the `torch.optim.Adam` class that
@@ -263,29 +305,29 @@ def step(self, *args, **kwargs):
# random variables with reparameterized samplers), but all the ELBO
# implementations in Pyro share the same basic logic.
def elbo(model, guide, *args, **kwargs):
# Run the guide with the arguments passed to SVI.step() and trace the execution,
# i.e. record all the calls to Pyro primitives like sample() and param().
guide_trace = trace(guide).get_trace(*args, **kwargs)
# Now run the model with the same arguments and trace the execution. Because
# model is being run with replay, whenever we encounter a sample site in the
# model, instead of sampling from the corresponding distribution in the model,
# we instead reuse the corresponding sample from the guide. In probabilistic
# terms, this means our loss is constructed as an expectation w.r.t. the joint
# distribution defined by the guide.
model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs)
# We will accumulate the various terms of the ELBO in `elbo`.
elbo = 0.
# Loop over all the sample sites in the model and add the corresponding
# log p(z) term to the ELBO. Note that this will also include any observed
# data, i.e. sample sites with the keyword `obs=...`.
for site in model_trace.values():
if site["type"] == "sample":
elbo = elbo + site["fn"].log_prob(site["value"]).sum()
# Loop over all the sample sites in the guide and add the corresponding
# -log q(z) term to the ELBO.
for site in guide_trace.values():
if site["type"] == "sample":
elbo = elbo - site["fn"].log_prob(site["value"]).sum()
# Return (-elbo) since by convention we do gradient descent on a loss and
# the ELBO is a lower bound that needs to be maximized.
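    # Record the guide's and the model's joint log densities as collections of lazy
    # funsor factors; log_joint gives latent sample sites free funsor.Variables as
    # values instead of drawing concrete samples.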
with log_joint() as guide_log_joint:
guide(*args, **kwargs)
with log_joint() as model_log_joint:
model(*args, **kwargs)
plates = frozenset(guide_log_joint.plates | model_log_joint.plates)
    eliminate = (frozenset(guide_log_joint.log_factors) |
                 frozenset(model_log_joint.log_factors))
factors = []
for p in model_log_joint.log_factors.values():
factors.append(p)
for q in guide_log_joint.log_factors.values():
factors.append(-q)
        factors.append(q.sample(eliminate).exp())

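    # Contract all of the collected factors into a single ELBO funsor using funsor's
    # sum-product variable elimination (logaddexp as the sum, add as the product, so
    # everything stays in log space).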
elbo = funsor.sum_product.sum_product(
sum_op=funsor.ops.logaddexp,
prod_op=funsor.ops.add,
factors=factors,
plates=plates,
eliminate=eliminate)

return -elbo


def Trace_ELBO(*args, **kwargs):
return elbo
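
# Example usage (an illustrative sketch): assuming the SVI and Adam wrappers defined
# elsewhere in this file keep minipyro's usual interface, a training loop would look
# roughly like:
#
#     svi = SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO())
#     for step in range(100):
#         svi.step(data)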
