Skip to content

Commit

Permalink
Fixes for examples/vae.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Mar 24, 2019
1 parent f10e09b commit 39586ad
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
23 changes: 15 additions & 8 deletions examples/vae.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import argparse
import os
from collections import OrderedDict

import torch
Expand All @@ -14,6 +15,9 @@
import funsor.ops as ops
from funsor.domains import bint, reals

REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_PATH = os.path.join(REPO_PATH, 'data')


class Encoder(nn.Module):
def __init__(self):
Expand All @@ -23,6 +27,7 @@ def __init__(self):
self.fc22 = nn.Linear(400, 20)

def forward(self, image):
image = image.reshape(image.shape[:-2] + (-1,))
h1 = F.relu(self.fc1(image))
loc = self.fc21(h1)
scale = self.fc22(h1)
Expand All @@ -31,7 +36,7 @@ def forward(self, image):

class Decoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
super(Decoder, self).__init__()
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)

Expand All @@ -44,15 +49,16 @@ def main(args):
encoder = Encoder()
decoder = Decoder()

encode = funsor.function(reals(28, 28), (reals(20), reals(20)))(encoder)
decode = funsor.function(reals(20), reals(28, 28))(decoder)
encode = funsor.torch.function(reals(28, 28), (reals(20), reals(20)))(encoder)
decode = funsor.torch.function(reals(20), reals(28, 28))(decoder)

@funsor.interpreter.interpretation(funsor.terms.monte_carlo)
@funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo)
def loss_function(data, scale):
loc, scale = encode(data)
i = funsor.Variable('i', bint(20))
z = funsor.Variable('z', reals(20))
q = dist.Normal(loc[i], scale[i], value=z[i])
assert isinstance(q, funsor.gaussian.Gaussian), q
q = q.reduce(ops.add, frozenset(['i']))

probs = decode(z)
Expand All @@ -67,18 +73,19 @@ def loss_function(data, scale):
return loss.data

train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
datasets.MNIST(DATA_PATH, train=True, download=True,
transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True)

encoder.train()
decoder.train()
optimizer = optim.Adam(encoder.parameters() +
decoder.parameters(), lr=1e-3)
optimizer = optim.Adam(list(encoder.parameters()) +
list(decoder.parameters()), lr=1e-3)
for epoch in range(args.num_epochs):
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
scale = float(len(train_loader.dataset) / len(data))
data = data[:, 0, :, :]
data = funsor.Tensor(data, OrderedDict(batch=bint(len(data))))

optimizer.zero_grad()
Expand All @@ -94,7 +101,7 @@ def loss_function(data, scale):

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('-n', '--num-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=8)
args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion funsor/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(self, offset):
assert offset >= 0
self.offset = offset
self._prefix = (slice(None),) * offset
self.__name__ = 'GetitemOp({})'.format(offset)
super(GetitemOp, self).__init__(self._default)
self.__name__ = 'GetitemOp({})'.format(offset)

def _default(self, x, y):
return x[self._prefix + (y,)] if self.offset else x[y]
Expand Down
5 changes: 4 additions & 1 deletion funsor/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,10 @@ def __name__(self):


def _function(inputs, output, fn):
names = getargspec(fn)[0]
if isinstance(fn, torch.nn.Module):
names = getargspec(fn.forward)[0][1:]
else:
names = getargspec(fn)[0]
args = tuple(Variable(name, domain) for (name, domain) in zip(names, inputs))
assert len(args) == len(inputs)
if not isinstance(output, Domain):
Expand Down

0 comments on commit 39586ad

Please sign in to comment.