diff --git a/funsor/delta.py b/funsor/delta.py index 5102d384f..5a28ccc3f 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -13,12 +13,11 @@ class DeltaMeta(FunsorMeta): """ - Wrapper to fill in defaults. + Wrapper to convert point to a funsor. """ - def __call__(cls, name, point, log_density=0): + def __call__(cls, name, point): point = to_funsor(point) - log_density = to_funsor(log_density) - return super(DeltaMeta, cls).__call__(name, point, log_density) + return super(DeltaMeta, cls).__call__(name, point) @add_metaclass(DeltaMeta) @@ -28,23 +27,16 @@ class Delta(Funsor): :param str name: Name of the bound variable. :param Funsor point: Value of the bound variable. - :param Funsor log_density: Optional log density to be added when evaluating - at a point. This is needed to make :class:`Delta` closed under - differentiable substitution. """ - def __init__(self, name, point, log_density=0): + def __init__(self, name, point): assert isinstance(name, str) assert isinstance(point, Funsor) - assert isinstance(log_density, Funsor) - assert log_density.output == reals() inputs = OrderedDict([(name, point.output)]) inputs.update(point.inputs) - inputs.update(log_density.inputs) output = reals() super(Delta, self).__init__(inputs, output) self.name = name self.point = point - self.log_density = log_density def eager_subs(self, subs): value = None @@ -62,17 +54,16 @@ def eager_subs(self, subs): name = self.name point = self.point.eager_subs(index_part) - log_density = self.log_density.eager_subs(index_part) if value is not None: if isinstance(value, Variable): name = value.name elif isinstance(value, (Number, Tensor)) and isinstance(point, (Number, Tensor)): - return (value == point).all().log() + log_density + return (value == point).all().log() else: # TODO Compute a jacobian, update log_prob, and emit another Delta. raise ValueError('Cannot substitute a {} into a Delta' .format(type(value).__name__)) - return Delta(name, point, log_density) + return Delta(name, point) def eager_reduce(self, op, reduced_vars): if op is ops.logaddexp: diff --git a/funsor/distributions.py b/funsor/distributions.py index 6c65412b3..4469af801 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -151,17 +151,16 @@ def eager_delta(v, log_density, value): return Tensor(data, inputs) -@eager.register(Delta, Funsor, Funsor, Variable) -@eager.register(Delta, Variable, Funsor, Variable) +@eager.register(Delta, (Funsor, Variable), Funsor, Variable) def eager_delta(v, log_density, value): assert v.output == value.output - return funsor.delta.Delta(value.name, v, log_density) + return funsor.delta.Delta(value.name, v) + log_density @eager.register(Delta, Variable, Funsor, Funsor) def eager_delta(v, log_density, value): assert v.output == value.output - return funsor.delta.Delta(v.name, value, log_density) + return funsor.delta.Delta(v.name, value) + log_density class Normal(Distribution): diff --git a/test/test_distributions.py b/test/test_distributions.py index 622a0678a..bb656b3f8 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -86,7 +86,7 @@ def test_delta_delta(): point = Tensor(torch.randn(2)) log_density = Tensor(torch.tensor(0.5)) d = dist.Delta(point, log_density, v) - assert d is Delta('v', point, log_density) + assert d is Delta('v', point) + log_density def test_normal_defaults():