Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial posteriors SMC and refactor to decouple tempering from SMC construction #729

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4e9880a
extracting taking last
ciguaran Jul 29, 2024
753b89b
test passing
ciguaran Jul 29, 2024
46676e2
layering
ciguaran Jul 29, 2024
1c26405
example
ciguaran Aug 6, 2024
6929d55
more
ciguaran Aug 6, 2024
e46e86c
Adding another example
ciguaran Aug 7, 2024
f0ce4ba
tests in place
ciguaran Aug 15, 2024
3576690
rolling back changes
ciguaran Aug 15, 2024
5c34177
Adding test for num_mcmc_steps
ciguaran Aug 15, 2024
8145cbb
format
ciguaran Aug 15, 2024
9550d5d
better test coverage
ciguaran Aug 16, 2024
02dccbe
Merge branch 'main' into ciguaran_waste_free_smc
ciguaran Aug 16, 2024
c06b6ab
linter
ciguaran Aug 16, 2024
3eb18cb
Merge branch 'ciguaran_waste_free_smc' of github.com:ciguaran/blackja…
ciguaran Aug 16, 2024
424599e
Flake8
ciguaran Aug 16, 2024
110b62e
black
ciguaran Aug 16, 2024
964ec95
implementation[
ciguaran Aug 21, 2024
93205f3
partial posteriors implementation
ciguaran Aug 27, 2024
c608b60
rolling back some changes
ciguaran Aug 27, 2024
359be77
linter
ciguaran Aug 27, 2024
cef4ef0
fixing test
ciguaran Aug 27, 2024
0942087
adding reference
ciguaran Aug 27, 2024
a9a4d0c
Merge branch 'main' into ciguaran_split_tempered_from_mcmc_construction
ciguaran Aug 27, 2024
1304f9f
typo
ciguaran Aug 27, 2024
aec2e51
exposing in top level api
ciguaran Aug 27, 2024
cededec
reruning precommit
ciguaran Aug 27, 2024
14919f2
adding more steps
ciguaran Sep 10, 2024
601b74a
smaller step size
ciguaran Sep 30, 2024
34d5ba3
Merge branch 'main' into ciguaran_split_tempered_from_mcmc_construction
ciguaran Sep 30, 2024
4d6089e
fixes on comments
ciguaran Oct 4, 2024
a5922ad
small fix on formating
ciguaran Oct 4, 2024
26d271e
renaming to data mask
ciguaran Oct 4, 2024
b87d8e4
linter
ciguaran Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
Expand Down Expand Up @@ -119,8 +120,9 @@ def generate_top_level_api_from(module):
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

smc_family = [tempered_smc, adaptive_tempered_smc]
smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
Expand Down
1 change: 1 addition & 0 deletions blackjax/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
"tempered",
"inner_kernel_tuning",
"extend_params",
"partial_posteriors_path",
]
28 changes: 28 additions & 0 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,31 @@ def extend_params(params):
"""

return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles
64 changes: 64 additions & 0 deletions blackjax/smc/from_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from functools import partial
from typing import Callable

import jax

from blackjax import smc
from blackjax.smc.base import SMCState, update_and_take_last
from blackjax.types import PRNGKey


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
update_strategy: Callable = update_and_take_last,
):
"""SMC step from MCMC kernels.
Builds MCMC kernels from the input parameters, which may change across iterations.
Moreover, it defines the way such kernels are used to update the particles. This layer
adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API
that depends on an update function over the set of particles.
Returns
-------
A callable that takes a rng_key and a state with .particles and .weights and returns a base.SMCState
and base.SMCInfo pair.

"""

def step(
rng_key: PRNGKey,
state,
num_mcmc_steps: int,
mcmc_parameters: dict,
logposterior_fn: Callable,
log_weights_fn: Callable,
) -> tuple[smc.base.SMCState, smc.base.SMCInfo]:
shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

return smc.base.step(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
)

return step
127 changes: 127 additions & 0 deletions blackjax/smc/partial_posteriors_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

from blackjax import SamplingAlgorithm, smc
from blackjax.smc.base import update_and_take_last
from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


class PartialPosteriorsSMCState(NamedTuple):
"""Current state for the tempered SMC algorithm.

particles: PyTree
The particles' positions.
weights:
Weights of the particles, so that they represent a probability distribution
data_mask:
A 1D boolean array to indicate which datapoints to include
in the computation of the observed likelihood.
"""

particles: ArrayTree
weights: Array
data_mask: Array


def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState:
"""num_datapoints are the number of observations that could potentially be
used in a partial posterior. Since the initial data_mask is all 0s, it
means that no likelihood term will be added (only prior).
"""
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
weights = jnp.ones(num_particles) / num_particles
return PartialPosteriorsSMCState(particles, weights, jnp.zeros(num_datapoints))


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
num_mcmc_steps: Optional[int],
mcmc_parameters: ArrayTree,
partial_logposterior_factory: Callable[[Array], Callable],
update_strategy=update_and_take_last,
) -> Callable:
"""Build the Partial Posteriors (data tempering) SMC kernel.
The distribution's trajectory includes increasingly adding more
datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936
Parameters
----------
mcmc_step_fn
A function that computes the log density of the prior distribution
mcmc_init_fn
A function that returns the probability at a given position.
resampling_fn
A random function that resamples generated particles based of weights
num_mcmc_steps
Number of iterations in the MCMC chain.
mcmc_parameters
A dictionary of parameters to be used by the inner MCMC kernels
partial_logposterior_factory:
A callable that given an array of 0 and 1, returns a function logposterior(x).
The array represents which values to include in the logposterior calculation. The logposterior
must be jax compilable.

Returns
-------
A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for
the current and previous posteriors, and takes a data-tempered SMC state.
"""
delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy)

def step(
key, state: PartialPosteriorsSMCState, data_mask: Array
) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]:
logposterior_fn = partial_logposterior_factory(data_mask)

previous_logposterior_fn = partial_logposterior_factory(state.data_mask)

def log_weights_fn(x):
return logposterior_fn(x) - previous_logposterior_fn(x)

state, info = delegate(
key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn
)

return (
PartialPosteriorsSMCState(state.particles, state.weights, data_mask),
info,
)

return step


def as_top_level_api(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: dict,
resampling_fn: Callable,
num_mcmc_steps,
partial_logposterior_factory: Callable,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""A factory that wraps the kernel into a SamplingAlgorithm object.
See build_kernel for full documentation on the parameters.
"""

kernel = build_kernel(
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
num_mcmc_steps,
mcmc_parameters,
partial_logposterior_factory,
update_strategy,
)

def init_fn(position: ArrayLikeTree, num_observations, rng_key=None):
del rng_key
return init(position, num_observations)

def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array):
return kernel(key, state, data_mask)

return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type]
66 changes: 11 additions & 55 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp

import blackjax.smc as smc
import blackjax.smc.from_mcmc as smc_from_mcmc
from blackjax.base import SamplingAlgorithm
from blackjax.smc.base import SMCState
from blackjax.smc.base import update_and_take_last
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"]
Expand Down Expand Up @@ -48,35 +48,6 @@ def init(particles: ArrayLikeTree):
return TemperedSMCState(particles, weights, 0.0)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""
Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
Expand Down Expand Up @@ -121,6 +92,9 @@ def build_kernel(
information about the transition.

"""
delegate = smc_from_mcmc.build_kernel(
mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy
)

def kernel(
rng_key: PRNGKey,
Expand Down Expand Up @@ -153,14 +127,6 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)

Expand All @@ -169,23 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

smc_state, info = smc.base.step(
smc_state, info = delegate(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
state,
num_mcmc_steps,
mcmc_parameters,
tempered_logposterior_fn,
log_weights_fn,
)

tempered_state = TemperedSMCState(
Expand Down
Loading
Loading