Skip to content

Commit

Permalink
Fix reference cycle
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Mar 3, 2019
1 parent 3a5d3a3 commit 5b038ce
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
10 changes: 4 additions & 6 deletions examples/neural_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@ def __init__(self, f, approx_fix_f):
self.approx_fix_f = approx_fix_f
self.buffers.loss = 0 # FIXME

def approx_f(_, *args):
return self.approx_fix_f(*args)

self.approx_f = approx_f
def approx_f(self, fn, *args):
return self.approx_fix_f(*args)

def forward(self, *args):
# TODO Unroll to multiple depths, gathering loss at each depth.
value0 = self.approx_fix_f(None, *args)
value1 = self.f(self.approx_f, *args)
self.buffers.loss = (value0 - value1).abs().sum()
value = 0.5 * (value0 + value1)
return value
return value1


def main(args):
Expand Down
10 changes: 5 additions & 5 deletions funsor/fixpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
def truncate(fn, output):
TRUNCATION_DEPTHS[fn] += 1
if TRUNCATION_DEPTHS[fn] <= MAX_TRUNCATION_DEPTH:
result = fn
approx_fn = fn
else:
def result(*args, **kwargs):
def approx_fn(fn, *args):
return Tensor(torch.tensor(float('nan')).expand(output.shape))
try:
yield result
yield approx_fn
finally:
TRUNCATION_DEPTHS[fn] -= 1

Expand Down Expand Up @@ -74,8 +74,8 @@ def eager_subs(self, subs):

@truncated.register(Fix, object, Domain, tuple)
def eager_function(fn, output, args):
with truncate(fn) as t:
return fn(t, *args)
with truncate(fn) as approx_fn:
return fn(approx_fn, *args)


def _fix(inputs, output, fn):
Expand Down

0 comments on commit 5b038ce

Please sign in to comment.