Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New integrator, and add some metadata to integrators.py #681

Merged
merged 56 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
b60e4ca
TESTS
reubenharry May 13, 2024
0c5aa2d
TESTS
reubenharry May 13, 2024
5eeb3e1
UPDATE DOCSTRING
reubenharry May 13, 2024
4a09156
ADD STREAMING VERSION
reubenharry May 13, 2024
dfb5ee0
ADD PRECONDITIONING TO MCLMC
reubenharry May 13, 2024
2ab3365
ADD PRECONDITIONING TO TUNING FOR MCLMC
reubenharry May 13, 2024
4cc3971
UPDATE GITIGNORE
reubenharry May 13, 2024
f987da3
UPDATE GITIGNORE
reubenharry May 13, 2024
dbab9a3
UPDATE TESTS
reubenharry May 13, 2024
a7ffdb8
UPDATE TESTS
reubenharry May 13, 2024
098f5ad
UPDATE TESTS
reubenharry May 13, 2024
5bd2a3f
ADD DOCSTRING
reubenharry May 13, 2024
4fc1453
ADD TEST
reubenharry May 13, 2024
3678428
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 13, 2024
203f1fd
STREAMING AVERAGE
reubenharry May 15, 2024
fc347d6
ADD TEST
reubenharry May 15, 2024
49410f9
REFACTOR RUN_INFERENCE_ALGORITHM
reubenharry May 15, 2024
ffdca93
UPDATE DOCSTRING
reubenharry May 15, 2024
b7b7084
Precommit
reubenharry May 15, 2024
9d2601d
RESOLVE MERGE CONFLICTS
reubenharry May 15, 2024
97cfc9e
CLEAN TESTS
reubenharry May 15, 2024
45429b8
CLEAN TESTS
reubenharry May 15, 2024
dd9fb1c
Merge branch 'preconditioned_mclmc' of https://github.com/reubenharry…
reubenharry May 15, 2024
a27dba9
GITIGNORE
reubenharry May 15, 2024
7a6e42b
PRECOMMIT CLEAN UP
reubenharry May 15, 2024
2d3c3fc
FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS
reubenharry May 15, 2024
dad0060
TEMPORARILY ADD BENCHMARKS
reubenharry May 15, 2024
6bd5ab1
ADD INITIAL_POSITION
reubenharry May 17, 2024
5615261
FIX TEST
reubenharry May 17, 2024
d66a561
Merge branch 'main' into inference_algorithm
reubenharry May 17, 2024
290addc
Merge branch 'main' into inference_algorithm
reubenharry May 18, 2024
35d4880
Merge branch 'inference_algorithm' into new_integrator
reubenharry May 18, 2024
356cd3b
CLEAN UP
reubenharry May 18, 2024
67c0002
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 18, 2024
63a8042
REMOVE BENCHMARKS
reubenharry May 18, 2024
51fee69
ADD TEST
reubenharry May 18, 2024
29994d7
REMOVE BENCHMARKS
reubenharry May 18, 2024
e4be0ae
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 18, 2024
64948e5
BUG FIX
reubenharry May 18, 2024
c3d44f3
CHANGE PRECISION
reubenharry May 18, 2024
94d43bd
CHANGE PRECISION
reubenharry May 18, 2024
17b7454
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 18, 2024
636ef43
ADD OMELYAN TEST
reubenharry May 18, 2024
178b452
RENAME O
reubenharry May 19, 2024
9c1c816
Merge branch 'inference_algorithm' of github.com:reubenharry/blackjax…
reubenharry May 19, 2024
db90cdc
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 19, 2024
a26d4a0
UPDATE STREAMING AVG
reubenharry May 19, 2024
0ff1d24
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 19, 2024
4e2b7c0
MERGE
reubenharry May 20, 2024
6bacb6c
UPDATE PR
reubenharry May 24, 2024
654dacc
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 24, 2024
9c2fea7
RENAME STD_MAT
reubenharry May 24, 2024
c249a12
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 24, 2024
06dd04d
MERGE MAIN
reubenharry May 25, 2024
88b06cb
MERGE MAIN
reubenharry May 26, 2024
abe707c
REMOVE COEFFICIENT EXPORTS
reubenharry May 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python

explore.py

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
95 changes: 81 additions & 14 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,23 @@
from blackjax.types import ArrayTree

__all__ = [
"velocity_verlet_coefficients",
"mclachlan_coefficients",
"yoshida_coefficients",
"omelyan_coefficients",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove? Are you planning to use it outside of integrators?

Copy link
Contributor Author

@reubenharry reubenharry May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I have found it convenient to use them outside integrators. For example, if I want to run scripts that try different integrators, and I want access to their number of gradient calls. I suppose the other option would be to have a dictionary like {"velocity_verlet": {"num_grads": ..., "name": ..., "order": ...} }. Is that what you'd recommend?

As another example, I often want to do benchmarks against different integrators, but want to just iterate over X_coefficients, and then use generate_isokinetic_integrator for MCLMC and generate_euclidean_integrator for HMC

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it seems like code duplication to have isokinetic_velocity_verlet, velocity_verlet, etc, when we could just use the coefficients with generate_isokinetic_integrator and generate_euclidean_integrator

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it seems like code duplication to have isokinetic_velocity_verlet, velocity_verlet, etc, when we could just use the coefficients with generate_isokinetic_integrator and generate_euclidean_integrator

The design choice is motivated by how we are envisioning the usage of integrators in the library. Currently the design is to have integrator being a static set.

As another example, I often want to do benchmarks against different integrators, but want to just iterate over X_coefficients, and then use generate_isokinetic_integrator for MCLMC and generate_euclidean_integrator for HMC

You should iterate through the integrator objects instead

For example, if I want to run scripts that try different integrators, and I want access to their number of gradient calls. I suppose the other option would be to have a dictionary like {"velocity_verlet": {"num_grads": ..., "name": ..., "order": ...} }. Is that what you'd recommend?

Yes given that these are static, you should put them in your script as static parameters. I dont yet see those are useful in the library outside of benchmarking.

Copy link
Contributor Author

@reubenharry reubenharry May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, that makes sense. I think my one point of disagreement is that if I don't expose the coefficients, nothing in the code "knows" that velocity_verlet and isokinetic_velocity_verlet are related. So I will have to have a dictionary of {"euclidean": velocity_verlet, "isokinetic": isokinetic_velocity_verlet, ...} when I want to compare each integrator on hmc vs mclmc, which I'm currently doing. This is a little painful, but not the end of the world

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, as per request

"mclachlan",
"omelyan",
"velocity_verlet",
"yoshida",
"implicit_midpoint",
"isokinetic_leapfrog",
"with_isokinetic_maruyama",
"isokinetic_velocity_verlet",
"isokinetic_mclachlan",
"isokinetic_omelyan",
"isokinetic_yoshida",
"implicit_midpoint",
"calls_per_integrator_step",
"name_integrator",
"integrator_order",
]


Expand Down Expand Up @@ -70,7 +80,7 @@

.. math:: \\frac{d}{dt}f = (O_1+O_2)f

The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}`
The velocity_verlet operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}`
by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`.

In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and
Expand Down Expand Up @@ -210,7 +220,7 @@
return IntegratorState(position, momentum, logdensity, logdensity_grad)


def generate_euclidean_integrator(cofficients):
def generate_euclidean_integrator(coefficients):
"""Generate symplectic integrator for solving a Hamiltonian system.

The resulting integrator is volume-preserve and preserves the symplectic structure
Expand All @@ -225,7 +235,7 @@
one_step = generalized_two_stage_integrator(
momentum_update_fn,
position_update_fn,
cofficients,
coefficients,
format_output_fn=format_euclidean_state_output,
)
return one_step
Expand All @@ -251,8 +261,8 @@
of the kinetic energy. We are trading accuracy in exchange, and it is not
clear whether this is the right tradeoff.
"""
velocity_verlet_cofficients = [0.5, 1.0, 0.5]
velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients)
velocity_verlet_coefficients = [0.5, 1.0, 0.5]
velocity_verlet = generate_euclidean_integrator(velocity_verlet_coefficients)

"""
Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`.
Expand All @@ -268,8 +278,8 @@
b1 = 0.1931833275037836
a1 = 0.5
b2 = 1 - 2 * b1
mclachlan_cofficients = [b1, a1, b2, a1, b1]
mclachlan = generate_euclidean_integrator(mclachlan_cofficients)
mclachlan_coefficients = [b1, a1, b2, a1, b1]
mclachlan = generate_euclidean_integrator(mclachlan_coefficients)

"""
Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical`
Expand All @@ -284,8 +294,20 @@
a1 = 0.29619504261126
b2 = 0.5 - b1
a2 = 1 - 2 * a1
yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1]
yoshida = generate_euclidean_integrator(yoshida_cofficients)
yoshida_coefficients = [b1, a1, b2, a2, b2, a1, b1]
yoshida = generate_euclidean_integrator(yoshida_coefficients)

"""11 stage Omelyan integrator [I.P. Omelyan, I.M. Mryglod and R. Folk, Comput. Phys. Commun. 151 (2003) 272.],
4MN5FV in [Takaishi, Tetsuya, and Philippe De Forcrand. "Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD." Physical Review E 73.3 (2006): 036706.]
popular in LQCD"""
b1 = 0.08398315262876693
a1 = 0.2539785108410595
b2 = 0.6822365335719091
a2 = -0.03230286765269967
b3 = 0.5 - b1 - b2
a3 = 1 - 2 * (a1 + a2)
omelyan_coefficients = [b1, a1, b2, a2, b3, a3, b3, a2, b2, a1, b1]
omelyan = generate_euclidean_integrator(omelyan_coefficients)


# Intergrators with non Euclidean updates
Expand Down Expand Up @@ -372,9 +394,54 @@
return isokinetic_integrator


isokinetic_leapfrog = generate_isokinetic_integrator(velocity_verlet_cofficients)
isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients)
isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients)
isokinetic_velocity_verlet = generate_isokinetic_integrator(
velocity_verlet_coefficients
)
isokinetic_yoshida = generate_isokinetic_integrator(yoshida_coefficients)
isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_coefficients)
isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients)


def calls_per_integrator_step(c):
if c == velocity_verlet_coefficients:
return 1
if c == mclachlan_coefficients:
return 2
if c == yoshida_coefficients:
return 3
if c == omelyan_coefficients:
return 5

Check warning on line 413 in blackjax/mcmc/integrators.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L406-L413

Added lines #L406 - L413 were not covered by tests

else:
raise Exception("No such integrator exists in blackjax")

Check warning on line 416 in blackjax/mcmc/integrators.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L416

Added line #L416 was not covered by tests


def name_integrator(c):
if c == velocity_verlet_coefficients:
return "velocity_verlet"
if c == mclachlan_coefficients:
return "mclachlan"
if c == yoshida_coefficients:
return "yoshida"
if c == omelyan_coefficients:
return "omelyan"

Check warning on line 427 in blackjax/mcmc/integrators.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L420-L427

Added lines #L420 - L427 were not covered by tests

else:
raise Exception("No such integrator exists in blackjax")

Check warning on line 430 in blackjax/mcmc/integrators.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L430

Added line #L430 was not covered by tests


def integrator_order(c):
if c == velocity_verlet_coefficients:
return 2
if c == mclachlan_coefficients:
return 2
if c == yoshida_coefficients:
return 4
if c == omelyan_coefficients:
return 4

Check warning on line 441 in blackjax/mcmc/integrators.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L434-L441

Added lines #L434 - L441 were not covered by tests

else:
raise Exception("No such integrator exists in blackjax")

Check warning on line 444 in blackjax/mcmc/integrators.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L444

Added line #L444 was not covered by tests


def partially_refresh_momentum(momentum, rng_key, step_size, L):
Expand Down
2 changes: 1 addition & 1 deletion blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def buildtree_integrate(

"""
if tree_depth == 0:
# Base case - take one leapfrog step in the direction v.
# Base case - take one velocity_verlet step in the direction v.
next_state = integrator(initial_state, direction * step_size)
new_proposal = generate_proposal(initial_energy, next_state)
is_diverging = -new_proposal.weight > divergence_threshold
Expand Down
12 changes: 8 additions & 4 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ def kinetic_energy(p, position=None):
"velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4},
"mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-4},
"yoshida": {"algorithm": integrators.yoshida, "precision": 1e-4},
"omelyan": {"algorithm": integrators.omelyan, "precision": 1e-4},
"implicit_midpoint": {
"algorithm": integrators.implicit_midpoint,
"precision": 1e-4,
},
"isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog},
"isokinetic_velocity_verlet": {"algorithm": integrators.isokinetic_velocity_verlet},
"isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan},
"isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida},
"isokinetic_omelyan": {"algorithm": integrators.isokinetic_omelyan},
}


Expand All @@ -168,6 +170,7 @@ class IntegratorTest(chex.TestCase):
"velocity_verlet",
"mclachlan",
"yoshida",
"omelyan",
"implicit_midpoint",
],
)
Expand Down Expand Up @@ -241,13 +244,13 @@ def test_esh_momentum_update(self, dims):
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)

@chex.all_variants(with_pmap=False)
def test_isokinetic_leapfrog(self):
def test_isokinetic_velocity_verlet(self):
cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]])
logdensity_fn = lambda x: stats.multivariate_normal.logpdf(
x, jnp.zeros([3]), cov
)

step = self.variant(integrators.isokinetic_leapfrog(logdensity_fn))
step = self.variant(integrators.isokinetic_velocity_verlet(logdensity_fn))

rng = jax.random.key(4263456)
key0, key1 = jax.random.split(rng, 2)
Expand Down Expand Up @@ -296,9 +299,10 @@ def test_isokinetic_leapfrog(self):
@chex.all_variants(with_pmap=False)
@parameterized.parameters(
[
"isokinetic_leapfrog",
"isokinetic_velocity_verlet",
"isokinetic_mclachlan",
"isokinetic_yoshida",
"isokinetic_omelyan",
],
)
def test_isokinetic_integrator(self, integrator_name):
Expand Down
Loading