Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjusts for the fact that IRMH proposal might not be symmetric #581

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,22 +270,28 @@ def kernel(
state: RWState,
logdensity_fn: Callable,
proposal_distribution: Callable,
proposal_logdensity_fn: Optional[Callable] = None,
) -> tuple[RWState, RWInfo]:
"""

Parameters
----------
proposal_distribution
A function that, given a PRNGKey, is able to produce a sample in the same
domain of the target distribution.
proposal_logdensity_fn:
For non-symmetric proposals, a function that returns the log-density
to obtain a given proposal knowing the current state. If it is not
provided we assume the proposal is symmetric.
"""

def proposal_generator(rng_key: PRNGKey, position: ArrayTree):
del position
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverse.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, just out of curiosity why is that del needed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not needed, mostly it is styling to signify the input arg is not needed.

return proposal_distribution(rng_key)

inner_kernel = build_rmh()
return inner_kernel(rng_key, state, logdensity_fn, proposal_generator)
return inner_kernel(
rng_key, state, logdensity_fn, proposal_generator, proposal_logdensity_fn
)

return kernel

Expand Down Expand Up @@ -318,7 +324,10 @@ class irmh:
proposal_distribution
A Callable that takes a random number generator and produces a new proposal. The
proposal is independent of the sampler's current state.

proposal_logdensity_fn:
For non-symmetric proposals, a function that returns the log-density
to obtain a given proposal knowing the current state. If it is not
provided we assume the proposal is symmetric.
Returns
-------
A ``SamplingAlgorithm``.
Expand All @@ -332,14 +341,21 @@ def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
proposal_distribution: Callable,
proposal_logdensity_fn: Optional[Callable] = None,
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, proposal_distribution)
return kernel(
rng_key,
state,
logdensity_fn,
proposal_distribution,
proposal_logdensity_fn,
)

return SamplingAlgorithm(init_fn, step_fn)

Expand Down
53 changes: 40 additions & 13 deletions tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,58 @@ def test_logdensity_accepts(position):


class IRMHTest(unittest.TestCase):
def proposal_distribution(self, key):
return jnp.array([10.0])

def logdensity_accepts(self, position):
"""
a logdensity that gets maximized after the step
"""
return 0.0 if all(position - 10.0 < 1e-10) else 0.5

def test_proposal_is_independent_of_position(self):
"""New position does not depend on previous"""
"""New position does not depend on previous position"""
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])
other_position = jnp.array([15000.0])

def proposal_distribution(key):
return jnp.array([10.0])

def test_logdensity_accepts(position):
"""
a logdensity that gets maximized after the step
"""
return 0.0 if all(position - 10.0 < 1e-10) else 0.5

step = build_irmh()

for previous_position in [initial_position, other_position]:
new_state, _ = step(
new_state, state_info = step(
rng_key,
RWState(position=previous_position, logdensity=1.0),
test_logdensity_accepts,
proposal_distribution,
self.logdensity_accepts,
self.proposal_distribution,
)
np.testing.assert_allclose(new_state.position, jnp.array([10.0]))
np.testing.assert_allclose(state_info.acceptance_rate, 0.367879, rtol=1e-5)

def test_non_symmetric_proposal(self):
"""
Given that proposal_logdensity_fn is included,
thus the proposal is non-symmetric.
When computing the acceptance of the proposed state
Then proposal_logdensity_fn value is taken into account
"""
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])

def test_proposal_logdensity(new_state, prev_state):
return 0.1 if all(new_state.position - 10 < 1e-10) else 0.5

step = build_irmh()

for previous_position in [initial_position]:
_, state_info = step(
rng_key,
RWState(position=previous_position, logdensity=1.0),
self.logdensity_accepts,
self.proposal_distribution,
test_proposal_logdensity,
)

np.testing.assert_allclose(state_info.acceptance_rate, 0.246597)


class RMHProposalTest(unittest.TestCase):
Expand Down