Skip to content

Commit

Permalink
Partial posteriors SMC and refactor to decouple tempering from SMC co…
Browse files Browse the repository at this point in the history
…nstruction (blackjax-devs#729)

* extracting taking last

* test passing

* layering

* example

* more

* Adding another example

* tests in place

* rolling back changes

* Adding test for num_mcmc_steps

* format

* better test coverage

* linter

* Flake8

* black

* implementation[

* partial posteriors implementation

* rolling back some changes

* linter

* fixing test

* adding reference

* typo

* exposing in top level api

* reruning precommit

* adding more steps

* smaller step size

* fixes on comments

* small fix on formating

* renaming to data mask

* linter
  • Loading branch information
aphc14 committed Oct 19, 2024
1 parent bc08239 commit 1dddd47
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 59 deletions.
3 changes: 2 additions & 1 deletion blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ def generate_top_level_api_from(module):
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

smc_family = [tempered_smc, adaptive_tempered_smc]
smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
Expand Down
1 change: 1 addition & 0 deletions blackjax/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
"tempered",
"inner_kernel_tuning",
"extend_params",
"partial_posteriors_path",
]
28 changes: 28 additions & 0 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,31 @@ def extend_params(params):
"""

return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles
64 changes: 64 additions & 0 deletions blackjax/smc/from_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from functools import partial
from typing import Callable

import jax

from blackjax import smc
from blackjax.smc.base import SMCState, update_and_take_last
from blackjax.types import PRNGKey


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
update_strategy: Callable = update_and_take_last,
):
"""SMC step from MCMC kernels.
Builds MCMC kernels from the input parameters, which may change across iterations.
Moreover, it defines the way such kernels are used to update the particles. This layer
adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API
that depends on an update function over the set of particles.
Returns
-------
A callable that takes a rng_key and a state with .particles and .weights and returns a base.SMCState
and base.SMCInfo pair.
"""

def step(
rng_key: PRNGKey,
state,
num_mcmc_steps: int,
mcmc_parameters: dict,
logposterior_fn: Callable,
log_weights_fn: Callable,
) -> tuple[smc.base.SMCState, smc.base.SMCInfo]:
shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

return smc.base.step(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
)

return step
127 changes: 127 additions & 0 deletions blackjax/smc/partial_posteriors_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

from blackjax import SamplingAlgorithm, smc
from blackjax.smc.base import update_and_take_last
from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


class PartialPosteriorsSMCState(NamedTuple):
"""Current state for the tempered SMC algorithm.
particles: PyTree
The particles' positions.
weights:
Weights of the particles, so that they represent a probability distribution
data_mask:
A 1D boolean array to indicate which datapoints to include
in the computation of the observed likelihood.
"""

particles: ArrayTree
weights: Array
data_mask: Array


def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState:
"""num_datapoints are the number of observations that could potentially be
used in a partial posterior. Since the initial data_mask is all 0s, it
means that no likelihood term will be added (only prior).
"""
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
weights = jnp.ones(num_particles) / num_particles
return PartialPosteriorsSMCState(particles, weights, jnp.zeros(num_datapoints))


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
num_mcmc_steps: Optional[int],
mcmc_parameters: ArrayTree,
partial_logposterior_factory: Callable[[Array], Callable],
update_strategy=update_and_take_last,
) -> Callable:
"""Build the Partial Posteriors (data tempering) SMC kernel.
The distribution's trajectory includes increasingly adding more
datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936
Parameters
----------
mcmc_step_fn
A function that computes the log density of the prior distribution
mcmc_init_fn
A function that returns the probability at a given position.
resampling_fn
A random function that resamples generated particles based of weights
num_mcmc_steps
Number of iterations in the MCMC chain.
mcmc_parameters
A dictionary of parameters to be used by the inner MCMC kernels
partial_logposterior_factory:
A callable that given an array of 0 and 1, returns a function logposterior(x).
The array represents which values to include in the logposterior calculation. The logposterior
must be jax compilable.
Returns
-------
A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for
the current and previous posteriors, and takes a data-tempered SMC state.
"""
delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy)

def step(
key, state: PartialPosteriorsSMCState, data_mask: Array
) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]:
logposterior_fn = partial_logposterior_factory(data_mask)

previous_logposterior_fn = partial_logposterior_factory(state.data_mask)

def log_weights_fn(x):
return logposterior_fn(x) - previous_logposterior_fn(x)

state, info = delegate(
key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn
)

return (
PartialPosteriorsSMCState(state.particles, state.weights, data_mask),
info,
)

return step


def as_top_level_api(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: dict,
resampling_fn: Callable,
num_mcmc_steps,
partial_logposterior_factory: Callable,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""A factory that wraps the kernel into a SamplingAlgorithm object.
See build_kernel for full documentation on the parameters.
"""

kernel = build_kernel(
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
num_mcmc_steps,
mcmc_parameters,
partial_logposterior_factory,
update_strategy,
)

def init_fn(position: ArrayLikeTree, num_observations, rng_key=None):
del rng_key
return init(position, num_observations)

def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array):
return kernel(key, state, data_mask)

return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type]
66 changes: 11 additions & 55 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp

import blackjax.smc as smc
import blackjax.smc.from_mcmc as smc_from_mcmc
from blackjax.base import SamplingAlgorithm
from blackjax.smc.base import SMCState
from blackjax.smc.base import update_and_take_last
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"]
Expand Down Expand Up @@ -48,35 +48,6 @@ def init(particles: ArrayLikeTree):
return TemperedSMCState(particles, weights, 0.0)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""
Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
Expand Down Expand Up @@ -121,6 +92,9 @@ def build_kernel(
information about the transition.
"""
delegate = smc_from_mcmc.build_kernel(
mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy
)

def kernel(
rng_key: PRNGKey,
Expand Down Expand Up @@ -153,14 +127,6 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)

Expand All @@ -169,23 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

smc_state, info = smc.base.step(
smc_state, info = delegate(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
state,
num_mcmc_steps,
mcmc_parameters,
tempered_logposterior_fn,
log_weights_fn,
)

tempered_state = TemperedSMCState(
Expand Down
Loading

0 comments on commit 1dddd47

Please sign in to comment.