Skip to content

Commit

Permalink
fix loc scale reparam with center=1 (#1059)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Jun 8, 2021
1 parent 4ad6b98 commit f212f32
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(self, name, fn, obs):
assert obs is None, "LocScaleReparam does not support observe statements"
centered = self.centered
if is_identically_one(centered):
return name, fn, obs
return fn, obs
event_shape = fn.event_shape
fn, expand_shape, event_dim = self._unwrap(fn)

Expand Down
4 changes: 2 additions & 2 deletions test/infer/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ def get_expected_probe(loc, scale):
return get_moments(trace["x"]["value"])

if "dist_type" == "Normal":
reparam = LocScaleReparam()
reparam = LocScaleReparam(centered)
else:
reparam = LocScaleReparam(shape_params=["df"])
reparam = LocScaleReparam(centered, shape_params=["df"])

def get_actual_probe(loc, scale):
with numpyro.handlers.trace() as trace:
Expand Down

0 comments on commit f212f32

Please sign in to comment.