Skip to content

Commit

Permalink
Adding minimal test
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Oct 31, 2023
1 parent 58b9999 commit 9138a86
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 15 deletions.
5 changes: 3 additions & 2 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
new_state, info = step(rng_key, state)
"""
from typing import Callable, NamedTuple, Optional, Tuple
from typing import Callable, NamedTuple, Optional

import jax
from jax import numpy as jnp
Expand Down Expand Up @@ -271,7 +271,7 @@ def kernel(
logdensity_fn: Callable,
proposal_distribution: Callable,
proposal_logdensity_fn: Optional[Callable] = None,
) -> Tuple[RWState, RWInfo]:
) -> tuple[RWState, RWInfo]:
"""
Parameters
----------
Expand All @@ -285,6 +285,7 @@ def kernel(
"""

def proposal_generator(rng_key: PRNGKey, position: ArrayTree):
del position
return proposal_distribution(rng_key)

inner_kernel = build_rmh()
Expand Down
53 changes: 40 additions & 13 deletions tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,58 @@ def test_logdensity_accepts(position):


class IRMHTest(unittest.TestCase):
def proposal_distribution(self, key):
return jnp.array([10.0])

def logdensity_accepts(self, position):
"""
a logdensity that gets maximized after the step
"""
return 0.0 if all(position - 10.0 < 1e-10) else 0.5

def test_proposal_is_independent_of_position(self):
"""New position does not depend on previous"""
"""New position does not depend on previous position"""
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])
other_position = jnp.array([15000.0])

def proposal_distribution(key):
return jnp.array([10.0])

def test_logdensity_accepts(position):
"""
a logdensity that gets maximized after the step
"""
return 0.0 if all(position - 10.0 < 1e-10) else 0.5

step = build_irmh()

for previous_position in [initial_position, other_position]:
new_state, _ = step(
new_state, state_info = step(
rng_key,
RWState(position=previous_position, logdensity=1.0),
test_logdensity_accepts,
proposal_distribution,
self.logdensity_accepts,
self.proposal_distribution,
)
np.testing.assert_allclose(new_state.position, jnp.array([10.0]))
np.testing.assert_allclose(state_info.acceptance_rate, 0.367879, rtol=1e-5)

def test_non_symmetric_proposal(self):
"""
Given that proposal_logdensity_fn is included,
thus the proposal is non-symmetric.
When computing the acceptance of the proposed state
Then proposal_logdensity_fn value is taken into account
"""
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])

def test_proposal_logdensity(new_state, prev_state):
return 0.1 if all(new_state.position - 10 < 1e-10) else 0.5

step = build_irmh()

for previous_position in [initial_position]:
_, state_info = step(
rng_key,
RWState(position=previous_position, logdensity=1.0),
self.logdensity_accepts,
self.proposal_distribution,
test_proposal_logdensity,
)

np.testing.assert_allclose(state_info.acceptance_rate, 0.246597)


class RMHProposalTest(unittest.TestCase):
Expand Down

0 comments on commit 9138a86

Please sign in to comment.