From fbe7f7535b4343eac50a425f6d90eefa74ccaab3 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 10 Nov 2023 22:35:47 +0100 Subject: [PATCH 01/78] initial draft of mclmc --- blackjax/explore.ipynb | 219 ++++++++++++++++++++++++++++++++++++++ blackjax/explore.py | 55 ++++++++++ blackjax/mcmc/__init__.py | 2 + blackjax/mcmc/mclmc.py | 218 +++++++++++++++++++++++++++++++++++++ 4 files changed, 494 insertions(+) create mode 100644 blackjax/explore.ipynb create mode 100644 blackjax/explore.py create mode 100644 blackjax/mcmc/mclmc.py diff --git a/blackjax/explore.ipynb b/blackjax/explore.ipynb new file mode 100644 index 000000000..901f9e956 --- /dev/null +++ b/blackjax/explore.ipynb @@ -0,0 +1,219 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1699217679.346915 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.scipy.stats as stats\n", + "import numpy as np\n", + "import sys\n", + "import blackjax\n", + "\n", + "observed = np.random.normal(10, 20, size=1000)\n", + "def logdensity_fn(x):\n", + " return jnp.sum(jnp.square(x))\n", + "\n", + "# Build the kernel\n", + "step_size = 1e-3\n", + "inverse_mass_matrix = jnp.array([1., 1.])\n", + "nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)\n", + "\n", + "# Initialize the state\n", + "initial_position = jnp.array([1.0,1.0])\n", + "# {\"loc\": 1., \"scale\": 2.}\n", + "state = nuts.init(initial_position)\n", + "\n", + "# Iterate\n", + "rng_key = jax.random.PRNGKey(0)\n", + "for _ in range(5):\n", + "\n", + " rng_key, nuts_key = jax.random.split(rng_key)\n", + " state, _ = nuts.step(nuts_key, state)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FOO\n" + ] + } + ], + "source": [ + "from blackjax.mcmc.mclmc import Parameters\n", + "\n", + "mclmc = blackjax.mcmc.mclmc.mclmc(logdensity_fn=logdensity_fn, d = 2, transform = lambda x: x, init_key = jax.random.PRNGKey(0), params = Parameters(1,1e-3, jnp.ones(2)))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from blackjax.base import SamplingAlgorithm\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.scipy.stats as stats\n", + "import numpy as np\n", + "import sys\n", + "import blackjax\n", + "# from blackjax.types import ArrayTree\n", + "\n", + "def logdensity_fn(x):\n", + " return jnp.sum(jnp.square(x))\n", + "\n", + "# Build the kernel\n", + "step_size = 1e-3\n", + "inverse_mass_matrix = jnp.array([1., 1.])\n", + "nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)\n", + "\n", + "flip = lambda f: lambda s, k : f(k,s)\n", + "\n", + "def run_sampling_algorithm(sampling_algorithm : SamplingAlgorithm, num_steps : int, initial_val, rng_key):\n", + " state = sampling_algorithm.init(initial_val)\n", + " keys = jax.random.split(rng_key, num_steps)\n", + " _, info = jax.lax.scan(flip(sampling_algorithm.step), state, keys)\n", + " return info\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "dot requires ndarray or scalar arguments, got at position 0.", + "output_type": "error", + 
"traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m run_sampling_algorithm(\n\u001b[1;32m 2\u001b[0m sampling_algorithm \u001b[39m=\u001b[39m mclmc, \n\u001b[1;32m 3\u001b[0m num_steps \u001b[39m=\u001b[39m \u001b[39m4\u001b[39m, \n\u001b[1;32m 4\u001b[0m initial_val \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39marray([\u001b[39m1.0\u001b[39m,\u001b[39m1.0\u001b[39m]),\n\u001b[1;32m 5\u001b[0m rng_key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39mPRNGKey(\u001b[39m0\u001b[39m)\n\u001b[1;32m 6\u001b[0m )\n", + "Cell \u001b[0;32mIn[3], line 23\u001b[0m, in \u001b[0;36mrun_sampling_algorithm\u001b[0;34m(sampling_algorithm, num_steps, initial_val, rng_key)\u001b[0m\n\u001b[1;32m 21\u001b[0m state \u001b[39m=\u001b[39m sampling_algorithm\u001b[39m.\u001b[39minit(initial_val)\n\u001b[1;32m 22\u001b[0m keys \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(rng_key, num_steps)\n\u001b[0;32m---> 23\u001b[0m _, info \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mlax\u001b[39m.\u001b[39mscan(flip(sampling_algorithm\u001b[39m.\u001b[39mstep), state, keys)\n\u001b[1;32m 24\u001b[0m \u001b[39mreturn\u001b[39;00m info\n", + " \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[3], line 18\u001b[0m, in \u001b[0;36m\u001b[0;34m(s, k)\u001b[0m\n\u001b[1;32m 15\u001b[0m inverse_mass_matrix \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39marray([\u001b[39m1.\u001b[39m, \u001b[39m1.\u001b[39m])\n\u001b[1;32m 16\u001b[0m nuts \u001b[39m=\u001b[39m blackjax\u001b[39m.\u001b[39mnuts(logdensity_fn, step_size, inverse_mass_matrix)\n\u001b[0;32m---> 18\u001b[0m flip \u001b[39m=\u001b[39m \u001b[39mlambda\u001b[39;00m f: \u001b[39mlambda\u001b[39;00m s, k : f(k,s)\n\u001b[1;32m 20\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun_sampling_algorithm\u001b[39m(sampling_algorithm : SamplingAlgorithm, num_steps : \u001b[39mint\u001b[39m, initial_val, rng_key):\n\u001b[1;32m 21\u001b[0m state \u001b[39m=\u001b[39m sampling_algorithm\u001b[39m.\u001b[39minit(initial_val)\n", + "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:206\u001b[0m, in \u001b[0;36mmclmc.__new__..step_fn\u001b[0;34m(rng_key, state)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstep_fn\u001b[39m(rng_key: PRNGKey, state):\n\u001b[0;32m--> 206\u001b[0m \u001b[39mreturn\u001b[39;00m kernel(\n\u001b[1;32m 207\u001b[0m rng_key,\n\u001b[1;32m 208\u001b[0m state,\n\u001b[1;32m 209\u001b[0m )\n", + "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:150\u001b[0m, in \u001b[0;36mbuild_kernel..kernel\u001b[0;34m(rng_key, state)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mkernel\u001b[39m(rng_key : PRNGKey, state : MCLMCState) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mtuple\u001b[39m[MCLMCState, MCLMCInfo]:\n\u001b[1;32m 148\u001b[0m x, u, l, g \u001b[39m=\u001b[39m state\n\u001b[0;32m--> 150\u001b[0m xx, uu, ll, gg, kinetic_change \u001b[39m=\u001b[39m move(x, u, g, rng_key, L, eps, sigma)\n\u001b[1;32m 151\u001b[0m de \u001b[39m=\u001b[39m kinetic_change \u001b[39m+\u001b[39m ll \u001b[39m-\u001b[39m l\n\u001b[1;32m 152\u001b[0m \u001b[39mreturn\u001b[39;00m MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, 
de)\n", + "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:125\u001b[0m, in \u001b[0;36mupdate..step\u001b[0;34m(x, u, g, random_key, L, eps, sigma)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"One step of the generalized dynamics.\"\"\"\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[39m# Hamiltonian step\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# print(\"BAR 3\")\u001b[39;00m\n\u001b[0;32m--> 125\u001b[0m xx, uu, ll, gg, kinetic_change \u001b[39m=\u001b[39m hamiltonian_dynamics(x\u001b[39m=\u001b[39mx, u\u001b[39m=\u001b[39mu, g\u001b[39m=\u001b[39mg, eps\u001b[39m=\u001b[39meps, sigma \u001b[39m=\u001b[39m sigma)\n\u001b[1;32m 127\u001b[0m \u001b[39m# Langevin-like noise\u001b[39;00m\n\u001b[1;32m 128\u001b[0m nu \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39msqrt((jnp\u001b[39m.\u001b[39mexp(\u001b[39m2\u001b[39m \u001b[39m*\u001b[39m eps \u001b[39m/\u001b[39m L) \u001b[39m-\u001b[39m \u001b[39m1.\u001b[39m) \u001b[39m/\u001b[39m d)\n", + "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:164\u001b[0m, in \u001b[0;36mminimal_norm..step\u001b[0;34m(x, u, g, eps, sigma)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.\"\"\"\u001b[39;00m\n\u001b[1;32m 163\u001b[0m \u001b[39m# V T V T V\u001b[39;00m\n\u001b[0;32m--> 164\u001b[0m uu, r1 \u001b[39m=\u001b[39m V(eps \u001b[39m*\u001b[39m lambda_c, u, g \u001b[39m*\u001b[39m sigma)\n\u001b[1;32m 165\u001b[0m xx, ll, gg \u001b[39m=\u001b[39m T(eps, x, \u001b[39m0.5\u001b[39m\u001b[39m*\u001b[39muu\u001b[39m*\u001b[39msigma)\n\u001b[1;32m 166\u001b[0m uu, r2 \u001b[39m=\u001b[39m V(eps \u001b[39m*\u001b[39m (\u001b[39m1\u001b[39m \u001b[39m-\u001b[39m \u001b[39m2\u001b[39m \u001b[39m*\u001b[39m lambda_c), uu, gg \u001b[39m*\u001b[39m sigma)\n", + "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:88\u001b[0m, in \u001b[0;36mupdate_momentum..update\u001b[0;34m(eps, u, g)\u001b[0m\n\u001b[1;32m 86\u001b[0m g_norm \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39msqrt(jnp\u001b[39m.\u001b[39msum(jnp\u001b[39m.\u001b[39msquare(g)))\n\u001b[1;32m 87\u001b[0m e \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m g \u001b[39m/\u001b[39m g_norm\n\u001b[0;32m---> 88\u001b[0m ue \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mdot(u, e)\n\u001b[1;32m 89\u001b[0m delta \u001b[39m=\u001b[39m eps \u001b[39m*\u001b[39m g_norm \u001b[39m/\u001b[39m (d\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 90\u001b[0m zeta \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mexp(\u001b[39m-\u001b[39mdelta)\n", + " \u001b[0;31m[... 
skipping hidden 12 frame]\u001b[0m\n", + "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3079\u001b[0m, in \u001b[0;36mdot\u001b[0;34m(a, b, precision, preferred_element_type)\u001b[0m\n\u001b[1;32m 3074\u001b[0m \u001b[39m@util\u001b[39m\u001b[39m.\u001b[39m_wraps(np\u001b[39m.\u001b[39mdot, lax_description\u001b[39m=\u001b[39m_PRECISION_DOC)\n\u001b[1;32m 3075\u001b[0m \u001b[39m@partial\u001b[39m(jit, static_argnames\u001b[39m=\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mprecision\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mpreferred_element_type\u001b[39m\u001b[39m'\u001b[39m), inline\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 3076\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdot\u001b[39m(a: ArrayLike, b: ArrayLike, \u001b[39m*\u001b[39m,\n\u001b[1;32m 3077\u001b[0m precision: PrecisionLike \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 3078\u001b[0m preferred_element_type: DTypeLike \u001b[39m|\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Array:\n\u001b[0;32m-> 3079\u001b[0m util\u001b[39m.\u001b[39mcheck_arraylike(\u001b[39m\"\u001b[39m\u001b[39mdot\u001b[39m\u001b[39m\"\u001b[39m, a, b)\n\u001b[1;32m 3080\u001b[0m dtypes\u001b[39m.\u001b[39mcheck_user_dtype_supported(preferred_element_type, \u001b[39m\"\u001b[39m\u001b[39mdot\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 3081\u001b[0m a, b \u001b[39m=\u001b[39m asarray(a), asarray(b)\n", + "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/jax/_src/numpy/util.py:328\u001b[0m, in \u001b[0;36mcheck_arraylike\u001b[0;34m(fun_name, *args)\u001b[0m\n\u001b[1;32m 325\u001b[0m pos, arg \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m((i, arg) \u001b[39mfor\u001b[39;00m i, arg \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(args)\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m _arraylike(arg))\n\u001b[1;32m 327\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m requires ndarray or scalar arguments, got \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m at position \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m--> 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(msg\u001b[39m.\u001b[39mformat(fun_name, \u001b[39mtype\u001b[39m(arg), pos))\n", + "\u001b[0;31mTypeError\u001b[0m: dot requires ndarray or scalar arguments, got at position 0." 
+ ] + } + ], + "source": [ + "run_sampling_algorithm(\n", + " sampling_algorithm = mclmc, \n", + " num_steps = 4, \n", + " initial_val = jnp.array([1.0,1.0]),\n", + " rng_key = jax.random.PRNGKey(0)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "NUTSInfo(momentum=Array([[ 0.09027783, 1.1181064 ],\n", + " [-1.7013309 , -0.5390966 ],\n", + " [ 0.770002 , 0.6343672 ],\n", + " [-2.1758103 , -1.742976 ]], dtype=float32), is_divergent=Array([False, False, False, False], dtype=bool), is_turning=Array([False, False, True, False], dtype=bool), energy=Array([-1.370843 , -0.11327934, -5.84114 , -3.32897 ], dtype=float32), trajectory_leftmost_state=IntegratorState(position=Array([[1.4206688 , 0.61661047],\n", + " [3.6223001 , 2.0243049 ],\n", + " [2.1698742 , 1.2298268 ],\n", + " [7.5688944 , 5.142932 ]], dtype=float32), momentum=Array([[ -1.4299476 , 0.10284508],\n", + " [ -5.2006354 , -2.6766868 ],\n", + " [ -0.82091624, -0.29235116],\n", + " [-10.429538 , -7.2140446 ]], dtype=float32), logdensity=Array([ 2.3985085, 17.218868 , 6.220828 , 83.73791 ], dtype=float32), logdensity_grad=Array([[ 2.8413377, 1.2332209],\n", + " [ 7.2446003, 4.0486097],\n", + " [ 4.3397484, 2.4596536],\n", + " [15.137789 , 10.285864 ]], dtype=float32)), trajectory_rightmost_state=IntegratorState(position=Array([[1.1562694 , 1.5285689 ],\n", + " [0.7418781 , 0.74051124],\n", + " [2.3045995 , 1.4024434 ],\n", + " [2.1707716 , 1.2942749 ]], dtype=float32), momentum=Array([[ 0.8258719 , 1.9807057 ],\n", + " [-1.3802809 , -0.25639924],\n", + " [ 1.3709687 , 0.99710405],\n", + " [-1.9035604 , -1.579 ]], dtype=float32), logdensity=Array([3.6734817, 1.09874 , 7.2780266, 6.387397 ], dtype=float32), logdensity_grad=Array([[2.3125389, 3.0571377],\n", + " [1.4837562, 1.4810225],\n", + " [4.609199 , 2.8048868],\n", + " [4.341543 , 2.5885499]], dtype=float32)), num_trajectory_expansions=Array([10, 10, 9, 10], dtype=int32, weak_type=True), num_integration_steps=Array([1023, 1023, 511, 1023], dtype=int32, weak_type=True), acceptance_rate=Array([0.9999976 , 0.99999094, 0.99999905, 0.99998814], dtype=float32))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\n", + "run_sampling_algorithm(\n", + " sampling_algorithm = nuts, \n", + " num_steps = 4, \n", + " initial_val = jnp.array([1.0,1.0]),\n", + " rng_key = jax.random.PRNGKey(0)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mclmc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "9a512360325173e57a9de5045b432cf87295adb93a786e8e4409a45c42185bb6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/blackjax/explore.py b/blackjax/explore.py new file mode 100644 index 000000000..bb07779d4 --- /dev/null +++ b/blackjax/explore.py @@ -0,0 +1,55 @@ +import jax +import jax.numpy as jnp +import jax.scipy.stats as stats +import numpy as np +import sys +import blackjax +from blackjax.base import SamplingAlgorithm +import jax +import jax.numpy as jnp +import jax.scipy.stats 
as stats +import numpy as np +import sys +import blackjax +from blackjax.mcmc.mclmc import Parameters + +def logdensity_fn(x): + return -0.5*jnp.sum(jnp.square(x-5)) + +# Build the kernel +inverse_mass_matrix = jnp.array([1.0, 1.0]) + +# Initialize the state +initial_position = jnp.array([1.0, 1.0]) + + +mclmc = blackjax.mcmc.mclmc.mclmc( + logdensity_fn=logdensity_fn, + d=2, + transform=lambda x: x, + init_key=jax.random.PRNGKey(0), + params=Parameters(0.56568545, 1.4142135, inverse_mass_matrix), +) + +# ? +# tuning() + +flip = lambda f: lambda s, k: f(k, s) + +def run_sampling_algorithm( + sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key +): + state = sampling_algorithm.init(initial_val) + keys = jax.random.split(rng_key, num_steps) + _, info = jax.lax.scan(flip(sampling_algorithm.step), state, keys) + return info + +out = run_sampling_algorithm( + sampling_algorithm=mclmc, + num_steps=10000, + initial_val=jnp.array([0.1, 0.1]), + rng_key=jax.random.PRNGKey(0), +) + +print(jnp.mean(out.transformed_x, axis=0)) + diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index ced412517..59975f5b1 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -7,6 +7,7 @@ nuts, periodic_orbital, random_walk, + mclmc ) __all__ = [ @@ -18,4 +19,5 @@ "periodic_orbital", "marginal_latent_gaussian", "random_walk", + "mclmc" ] diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py new file mode 100644 index 000000000..cf3bd8c63 --- /dev/null +++ b/blackjax/mcmc/mclmc.py @@ -0,0 +1,218 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the MCLMC Kernel""" +from typing import Callable, NamedTuple, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +import blackjax.mcmc.proposal as proposal +import blackjax.mcmc.trajectory as trajectory +from blackjax.base import SamplingAlgorithm +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey + +__all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc"] + +class Parameters(NamedTuple): + """Tunable parameters + """ + + L: float + eps: float + sigma: Array + + +# MCLMCState = integrators.IntegratorState + + +class MCLMCState(NamedTuple): + """State of the MCLMC algorithm. + + """ + + x: Array + u: Array + l: float + g: Array + +class MCLMCInfo(NamedTuple): + """Additional information on the MCLMC transition. + + This additional information can be used for debugging or computing + diagnostics. 
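+
+    transformed_x: the transformed position after the transition, i.e. transform(x).
+    l: the negative log-density at the new position (this draft differentiates -logdensity_fn).
+    de: the change in total energy (kinetic plus potential) over the transition.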
+ """ + + transformed_x: Array + l: Array + de: float + +def init(x_initial : ArrayLikeTree, logdensity_fn, random_key): + + grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) + l, g = grad_nlogp(x_initial) + + u = random_unit_vector(random_key, d=x_initial.shape[0]) + + return MCLMCState(x_initial, u, l, g) + + +def random_unit_vector(random_key,d): + u = jax.random.normal(jax.random.PRNGKey(0), shape = (d, )) + u /= jnp.sqrt(jnp.sum(jnp.square(u))) + return u + + +def update_position(grad_nlogp): + + def update(eps, x, u): + xx = x + eps * u + ll, gg = grad_nlogp(xx) + return xx, ll, gg + + return update + +def update_momentum(d): + """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) + similar to the implementation: https://github.com/gregversteeg/esh_dynamics + There are no exponentials e^delta, which prevents overflows when the gradient norm is large.""" + + + def update(eps, u, g): + g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) + e = - g / g_norm + ue = jnp.dot(u, e) + delta = eps * g_norm / (d-1) + zeta = jnp.exp(-delta) + uu = e *(1-zeta)*(1+zeta + ue * (1-zeta)) + 2*zeta* u + delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1-ue)*zeta**2) + return uu/jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r + + return update + +def partially_refresh_momentum(d, sequential= True): + """Adds a small noise to u and normalizes.""" + + + def rng_sequential(u, random_key, nu): + z = nu * jax.random.normal(random_key, shape = (d, )) + + return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) + + +# def rng_parallel(u, random_key, nu): +# key, subkey = jax.random.split(random_key) +# noise = nu * jax.random.normal(subkey, shape= u.shape, dtype=u.dtype) + +# return (u + noise) / jnp.sqrt(jnp.sum(jnp.square(u + noise), axis = 1))[:, None], key + + + return rng_sequential + +def update(hamiltonian_dynamics, partially_refresh_momentum, d): + +# print("BAR 4") + def step(x, u, g, random_key, L, eps, sigma): + """One step of the generalized dynamics.""" + + # Hamiltonian step + # print("BAR 3") + xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, eps=eps, sigma = sigma) + + # Langevin-like noise + nu = jnp.sqrt((jnp.exp(2 * eps / L) - 1.) 
/ d) + uu = partially_refresh_momentum(u= uu, random_key= random_key, nu= nu) + + return xx, uu, ll, gg, kinetic_change + + return step + +def build_kernel(grad_nlogp, d, integrator, transform, params): + + L, eps, sigma = params + + hamiltonian_step, _ = integrator(T= update_position(grad_nlogp), + V= update_momentum(d), + d= d) + # print("BAR") + move = update(hamiltonian_step, partially_refresh_momentum(d), d) + # print("BAZ") + + def kernel(rng_key : PRNGKey, state : MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: + + x, u, l, g = state + + + xx, uu, ll, gg, kinetic_change = move(x, u, g, rng_key, L, eps, sigma) + de = kinetic_change + ll - l + return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) + + return kernel + +lambda_c = 0.1931833275037836 #critical value of the lambda parameter for the minimal norm integrator + +def minimal_norm(d, T, V): + + def step(x, u, g, eps, sigma): + """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" + + # V T V T V + uu, r1 = V(eps * lambda_c, u, g * sigma) + xx, ll, gg = T(eps, x, 0.5*uu*sigma) + uu, r2 = V(eps * (1 - 2 * lambda_c), uu, gg * sigma) + xx, ll, gg = T(eps, xx, 0.5*uu*sigma) + uu, r3 = V(eps * lambda_c, uu, gg * sigma) + + #kinetic energy change + kinetic_change = (r1 + r2 + r3) * (d-1) + + return xx, uu, ll, gg, kinetic_change + + return step, 2 + + + +class mclmc: + """todo: add documentation""" + + init = staticmethod(init) + build_kernel = staticmethod(build_kernel) + + def __new__( # type: ignore[misc] + cls, + logdensity_fn: Callable, + d : int, + transform : Callable, + params : Parameters, + init_key, + *, + integrator = minimal_norm, + ) -> SamplingAlgorithm: + + grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) + + kernel = cls.build_kernel(grad_nlogp, d, integrator, transform, params) + + def init_fn(position: ArrayLikeTree): + return cls.init(position, logdensity_fn, init_key) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + ) + + return SamplingAlgorithm(init_fn, step_fn) + From 3a23242318f0c5484fdf4d979ea29ce938663a1e Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 17:24:38 +0100 Subject: [PATCH 02/78] refactor --- blackjax/explore.ipynb | 219 ------------------------------ blackjax/mcmc/mclmc.py | 104 +++++--------- blackjax/explore.py => explore.py | 24 ++-- 3 files changed, 45 insertions(+), 302 deletions(-) delete mode 100644 blackjax/explore.ipynb rename blackjax/explore.py => explore.py (58%) diff --git a/blackjax/explore.ipynb b/blackjax/explore.ipynb deleted file mode 100644 index 901f9e956..000000000 --- a/blackjax/explore.ipynb +++ /dev/null @@ -1,219 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "I0000 00:00:1699217679.346915 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import jax.scipy.stats as stats\n", - "import numpy as np\n", - "import sys\n", - "import blackjax\n", - "\n", - "observed = np.random.normal(10, 20, size=1000)\n", - "def logdensity_fn(x):\n", - " return jnp.sum(jnp.square(x))\n", - "\n", - "# Build the kernel\n", - "step_size = 1e-3\n", - "inverse_mass_matrix = jnp.array([1., 1.])\n", - "nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)\n", - "\n", - "# Initialize the state\n", 
- "initial_position = jnp.array([1.0,1.0])\n", - "# {\"loc\": 1., \"scale\": 2.}\n", - "state = nuts.init(initial_position)\n", - "\n", - "# Iterate\n", - "rng_key = jax.random.PRNGKey(0)\n", - "for _ in range(5):\n", - "\n", - " rng_key, nuts_key = jax.random.split(rng_key)\n", - " state, _ = nuts.step(nuts_key, state)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "FOO\n" - ] - } - ], - "source": [ - "from blackjax.mcmc.mclmc import Parameters\n", - "\n", - "mclmc = blackjax.mcmc.mclmc.mclmc(logdensity_fn=logdensity_fn, d = 2, transform = lambda x: x, init_key = jax.random.PRNGKey(0), params = Parameters(1,1e-3, jnp.ones(2)))\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from blackjax.base import SamplingAlgorithm\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import jax.scipy.stats as stats\n", - "import numpy as np\n", - "import sys\n", - "import blackjax\n", - "# from blackjax.types import ArrayTree\n", - "\n", - "def logdensity_fn(x):\n", - " return jnp.sum(jnp.square(x))\n", - "\n", - "# Build the kernel\n", - "step_size = 1e-3\n", - "inverse_mass_matrix = jnp.array([1., 1.])\n", - "nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)\n", - "\n", - "flip = lambda f: lambda s, k : f(k,s)\n", - "\n", - "def run_sampling_algorithm(sampling_algorithm : SamplingAlgorithm, num_steps : int, initial_val, rng_key):\n", - " state = sampling_algorithm.init(initial_val)\n", - " keys = jax.random.split(rng_key, num_steps)\n", - " _, info = jax.lax.scan(flip(sampling_algorithm.step), state, keys)\n", - " return info\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "dot requires ndarray or scalar arguments, got at position 0.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m run_sampling_algorithm(\n\u001b[1;32m 2\u001b[0m sampling_algorithm \u001b[39m=\u001b[39m mclmc, \n\u001b[1;32m 3\u001b[0m num_steps \u001b[39m=\u001b[39m \u001b[39m4\u001b[39m, \n\u001b[1;32m 4\u001b[0m initial_val \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39marray([\u001b[39m1.0\u001b[39m,\u001b[39m1.0\u001b[39m]),\n\u001b[1;32m 5\u001b[0m rng_key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39mPRNGKey(\u001b[39m0\u001b[39m)\n\u001b[1;32m 6\u001b[0m )\n", - "Cell \u001b[0;32mIn[3], line 23\u001b[0m, in \u001b[0;36mrun_sampling_algorithm\u001b[0;34m(sampling_algorithm, num_steps, initial_val, rng_key)\u001b[0m\n\u001b[1;32m 21\u001b[0m state \u001b[39m=\u001b[39m sampling_algorithm\u001b[39m.\u001b[39minit(initial_val)\n\u001b[1;32m 22\u001b[0m keys \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(rng_key, num_steps)\n\u001b[0;32m---> 23\u001b[0m _, info \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mlax\u001b[39m.\u001b[39mscan(flip(sampling_algorithm\u001b[39m.\u001b[39mstep), state, keys)\n\u001b[1;32m 24\u001b[0m \u001b[39mreturn\u001b[39;00m info\n", - " \u001b[0;31m[... 
skipping hidden 9 frame]\u001b[0m\n", - "Cell \u001b[0;32mIn[3], line 18\u001b[0m, in \u001b[0;36m\u001b[0;34m(s, k)\u001b[0m\n\u001b[1;32m 15\u001b[0m inverse_mass_matrix \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39marray([\u001b[39m1.\u001b[39m, \u001b[39m1.\u001b[39m])\n\u001b[1;32m 16\u001b[0m nuts \u001b[39m=\u001b[39m blackjax\u001b[39m.\u001b[39mnuts(logdensity_fn, step_size, inverse_mass_matrix)\n\u001b[0;32m---> 18\u001b[0m flip \u001b[39m=\u001b[39m \u001b[39mlambda\u001b[39;00m f: \u001b[39mlambda\u001b[39;00m s, k : f(k,s)\n\u001b[1;32m 20\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun_sampling_algorithm\u001b[39m(sampling_algorithm : SamplingAlgorithm, num_steps : \u001b[39mint\u001b[39m, initial_val, rng_key):\n\u001b[1;32m 21\u001b[0m state \u001b[39m=\u001b[39m sampling_algorithm\u001b[39m.\u001b[39minit(initial_val)\n", - "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:206\u001b[0m, in \u001b[0;36mmclmc.__new__..step_fn\u001b[0;34m(rng_key, state)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstep_fn\u001b[39m(rng_key: PRNGKey, state):\n\u001b[0;32m--> 206\u001b[0m \u001b[39mreturn\u001b[39;00m kernel(\n\u001b[1;32m 207\u001b[0m rng_key,\n\u001b[1;32m 208\u001b[0m state,\n\u001b[1;32m 209\u001b[0m )\n", - "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:150\u001b[0m, in \u001b[0;36mbuild_kernel..kernel\u001b[0;34m(rng_key, state)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mkernel\u001b[39m(rng_key : PRNGKey, state : MCLMCState) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mtuple\u001b[39m[MCLMCState, MCLMCInfo]:\n\u001b[1;32m 148\u001b[0m x, u, l, g \u001b[39m=\u001b[39m state\n\u001b[0;32m--> 150\u001b[0m xx, uu, ll, gg, kinetic_change \u001b[39m=\u001b[39m move(x, u, g, rng_key, L, eps, sigma)\n\u001b[1;32m 151\u001b[0m de \u001b[39m=\u001b[39m kinetic_change \u001b[39m+\u001b[39m ll \u001b[39m-\u001b[39m l\n\u001b[1;32m 152\u001b[0m \u001b[39mreturn\u001b[39;00m MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de)\n", - "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:125\u001b[0m, in \u001b[0;36mupdate..step\u001b[0;34m(x, u, g, random_key, L, eps, sigma)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"One step of the generalized dynamics.\"\"\"\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[39m# Hamiltonian step\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# print(\"BAR 3\")\u001b[39;00m\n\u001b[0;32m--> 125\u001b[0m xx, uu, ll, gg, kinetic_change \u001b[39m=\u001b[39m hamiltonian_dynamics(x\u001b[39m=\u001b[39mx, u\u001b[39m=\u001b[39mu, g\u001b[39m=\u001b[39mg, eps\u001b[39m=\u001b[39meps, sigma \u001b[39m=\u001b[39m sigma)\n\u001b[1;32m 127\u001b[0m \u001b[39m# Langevin-like noise\u001b[39;00m\n\u001b[1;32m 128\u001b[0m nu \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39msqrt((jnp\u001b[39m.\u001b[39mexp(\u001b[39m2\u001b[39m \u001b[39m*\u001b[39m eps \u001b[39m/\u001b[39m L) \u001b[39m-\u001b[39m \u001b[39m1.\u001b[39m) \u001b[39m/\u001b[39m d)\n", - "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:164\u001b[0m, in \u001b[0;36mminimal_norm..step\u001b[0;34m(x, u, g, eps, sigma)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.\"\"\"\u001b[39;00m\n\u001b[1;32m 
163\u001b[0m \u001b[39m# V T V T V\u001b[39;00m\n\u001b[0;32m--> 164\u001b[0m uu, r1 \u001b[39m=\u001b[39m V(eps \u001b[39m*\u001b[39m lambda_c, u, g \u001b[39m*\u001b[39m sigma)\n\u001b[1;32m 165\u001b[0m xx, ll, gg \u001b[39m=\u001b[39m T(eps, x, \u001b[39m0.5\u001b[39m\u001b[39m*\u001b[39muu\u001b[39m*\u001b[39msigma)\n\u001b[1;32m 166\u001b[0m uu, r2 \u001b[39m=\u001b[39m V(eps \u001b[39m*\u001b[39m (\u001b[39m1\u001b[39m \u001b[39m-\u001b[39m \u001b[39m2\u001b[39m \u001b[39m*\u001b[39m lambda_c), uu, gg \u001b[39m*\u001b[39m sigma)\n", - "File \u001b[0;32m~/Library/CloudStorage/Dropbox/Reuben/Work/Berkeley/blackjax/blackjax/mcmc/mclmc.py:88\u001b[0m, in \u001b[0;36mupdate_momentum..update\u001b[0;34m(eps, u, g)\u001b[0m\n\u001b[1;32m 86\u001b[0m g_norm \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39msqrt(jnp\u001b[39m.\u001b[39msum(jnp\u001b[39m.\u001b[39msquare(g)))\n\u001b[1;32m 87\u001b[0m e \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m g \u001b[39m/\u001b[39m g_norm\n\u001b[0;32m---> 88\u001b[0m ue \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mdot(u, e)\n\u001b[1;32m 89\u001b[0m delta \u001b[39m=\u001b[39m eps \u001b[39m*\u001b[39m g_norm \u001b[39m/\u001b[39m (d\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 90\u001b[0m zeta \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mexp(\u001b[39m-\u001b[39mdelta)\n", - " \u001b[0;31m[... skipping hidden 12 frame]\u001b[0m\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3079\u001b[0m, in \u001b[0;36mdot\u001b[0;34m(a, b, precision, preferred_element_type)\u001b[0m\n\u001b[1;32m 3074\u001b[0m \u001b[39m@util\u001b[39m\u001b[39m.\u001b[39m_wraps(np\u001b[39m.\u001b[39mdot, lax_description\u001b[39m=\u001b[39m_PRECISION_DOC)\n\u001b[1;32m 3075\u001b[0m \u001b[39m@partial\u001b[39m(jit, static_argnames\u001b[39m=\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mprecision\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mpreferred_element_type\u001b[39m\u001b[39m'\u001b[39m), inline\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 3076\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdot\u001b[39m(a: ArrayLike, b: ArrayLike, \u001b[39m*\u001b[39m,\n\u001b[1;32m 3077\u001b[0m precision: PrecisionLike \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 3078\u001b[0m preferred_element_type: DTypeLike \u001b[39m|\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Array:\n\u001b[0;32m-> 3079\u001b[0m util\u001b[39m.\u001b[39mcheck_arraylike(\u001b[39m\"\u001b[39m\u001b[39mdot\u001b[39m\u001b[39m\"\u001b[39m, a, b)\n\u001b[1;32m 3080\u001b[0m dtypes\u001b[39m.\u001b[39mcheck_user_dtype_supported(preferred_element_type, \u001b[39m\"\u001b[39m\u001b[39mdot\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 3081\u001b[0m a, b \u001b[39m=\u001b[39m asarray(a), asarray(b)\n", - "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/jax/_src/numpy/util.py:328\u001b[0m, in \u001b[0;36mcheck_arraylike\u001b[0;34m(fun_name, *args)\u001b[0m\n\u001b[1;32m 325\u001b[0m pos, arg \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m((i, arg) \u001b[39mfor\u001b[39;00m i, arg \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(args)\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m _arraylike(arg))\n\u001b[1;32m 327\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m 
requires ndarray or scalar arguments, got \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m at position \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m--> 328\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(msg\u001b[39m.\u001b[39mformat(fun_name, \u001b[39mtype\u001b[39m(arg), pos))\n", - "\u001b[0;31mTypeError\u001b[0m: dot requires ndarray or scalar arguments, got at position 0." - ] - } - ], - "source": [ - "run_sampling_algorithm(\n", - " sampling_algorithm = mclmc, \n", - " num_steps = 4, \n", - " initial_val = jnp.array([1.0,1.0]),\n", - " rng_key = jax.random.PRNGKey(0)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "NUTSInfo(momentum=Array([[ 0.09027783, 1.1181064 ],\n", - " [-1.7013309 , -0.5390966 ],\n", - " [ 0.770002 , 0.6343672 ],\n", - " [-2.1758103 , -1.742976 ]], dtype=float32), is_divergent=Array([False, False, False, False], dtype=bool), is_turning=Array([False, False, True, False], dtype=bool), energy=Array([-1.370843 , -0.11327934, -5.84114 , -3.32897 ], dtype=float32), trajectory_leftmost_state=IntegratorState(position=Array([[1.4206688 , 0.61661047],\n", - " [3.6223001 , 2.0243049 ],\n", - " [2.1698742 , 1.2298268 ],\n", - " [7.5688944 , 5.142932 ]], dtype=float32), momentum=Array([[ -1.4299476 , 0.10284508],\n", - " [ -5.2006354 , -2.6766868 ],\n", - " [ -0.82091624, -0.29235116],\n", - " [-10.429538 , -7.2140446 ]], dtype=float32), logdensity=Array([ 2.3985085, 17.218868 , 6.220828 , 83.73791 ], dtype=float32), logdensity_grad=Array([[ 2.8413377, 1.2332209],\n", - " [ 7.2446003, 4.0486097],\n", - " [ 4.3397484, 2.4596536],\n", - " [15.137789 , 10.285864 ]], dtype=float32)), trajectory_rightmost_state=IntegratorState(position=Array([[1.1562694 , 1.5285689 ],\n", - " [0.7418781 , 0.74051124],\n", - " [2.3045995 , 1.4024434 ],\n", - " [2.1707716 , 1.2942749 ]], dtype=float32), momentum=Array([[ 0.8258719 , 1.9807057 ],\n", - " [-1.3802809 , -0.25639924],\n", - " [ 1.3709687 , 0.99710405],\n", - " [-1.9035604 , -1.579 ]], dtype=float32), logdensity=Array([3.6734817, 1.09874 , 7.2780266, 6.387397 ], dtype=float32), logdensity_grad=Array([[2.3125389, 3.0571377],\n", - " [1.4837562, 1.4810225],\n", - " [4.609199 , 2.8048868],\n", - " [4.341543 , 2.5885499]], dtype=float32)), num_trajectory_expansions=Array([10, 10, 9, 10], dtype=int32, weak_type=True), num_integration_steps=Array([1023, 1023, 511, 1023], dtype=int32, weak_type=True), acceptance_rate=Array([0.9999976 , 0.99999094, 0.99999905, 0.99998814], dtype=float32))" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "\n", - "run_sampling_algorithm(\n", - " sampling_algorithm = nuts, \n", - " num_steps = 4, \n", - " initial_val = jnp.array([1.0,1.0]),\n", - " rng_key = jax.random.PRNGKey(0)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mclmc", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": 
"9a512360325173e57a9de5045b432cf87295adb93a786e8e4409a45c42185bb6" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index cf3bd8c63..858ae91ba 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -24,30 +24,20 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc"] +__all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc", "Parameters"] class Parameters(NamedTuple): """Tunable parameters """ L: float - eps: float - sigma: Array + step_size: float + inverse_mass_matrix: Array -# MCLMCState = integrators.IntegratorState +MCLMCState = integrators.IntegratorState -class MCLMCState(NamedTuple): - """State of the MCLMC algorithm. - - """ - - x: Array - u: Array - l: float - g: Array - class MCLMCInfo(NamedTuple): """Additional information on the MCLMC transition. @@ -61,8 +51,8 @@ class MCLMCInfo(NamedTuple): def init(x_initial : ArrayLikeTree, logdensity_fn, random_key): - grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) - l, g = grad_nlogp(x_initial) + grad_logp = jax.value_and_grad(logdensity_fn) + l, g = grad_logp(x_initial) u = random_unit_vector(random_key, d=x_initial.shape[0]) @@ -75,11 +65,11 @@ def random_unit_vector(random_key,d): return u -def update_position(grad_nlogp): +def update_position(grad_logp): - def update(eps, x, u): - xx = x + eps * u - ll, gg = grad_nlogp(xx) + def update(step_size, x, u): + xx = x + step_size * u + ll, gg = grad_logp(xx) return xx, ll, gg return update @@ -90,11 +80,11 @@ def update_momentum(d): There are no exponentials e^delta, which prevents overflows when the gradient norm is large.""" - def update(eps, u, g): + def update(step_size, u, g): g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) - e = - g / g_norm + e = g / g_norm ue = jnp.dot(u, e) - delta = eps * g_norm / (d-1) + delta = step_size * g_norm / (d-1) zeta = jnp.exp(-delta) uu = e *(1-zeta)*(1+zeta + ue * (1-zeta)) + 2*zeta* u delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1-ue)*zeta**2) @@ -102,60 +92,47 @@ def update(eps, u, g): return update -def partially_refresh_momentum(d, sequential= True): - """Adds a small noise to u and normalizes.""" - - def rng_sequential(u, random_key, nu): - z = nu * jax.random.normal(random_key, shape = (d, )) +def partially_refresh_momentum(u, random_key, nu): + """Adds a small noise to u and normalizes.""" + z = nu * jax.random.normal(random_key, shape = (u.shape[0], )) return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - -# def rng_parallel(u, random_key, nu): -# key, subkey = jax.random.split(random_key) -# noise = nu * jax.random.normal(subkey, shape= u.shape, dtype=u.dtype) -# return (u + noise) / jnp.sqrt(jnp.sum(jnp.square(u + noise), axis = 1))[:, None], key - return rng_sequential - def update(hamiltonian_dynamics, partially_refresh_momentum, d): -# print("BAR 4") - def step(x, u, g, random_key, L, eps, sigma): + def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): """One step of the generalized dynamics.""" # Hamiltonian step - # print("BAR 3") - xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, eps=eps, sigma = sigma) + xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, step_size=step_size, inverse_mass_matrix = inverse_mass_matrix) # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * eps / L) - 1.) / d) + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.) 
/ d) uu = partially_refresh_momentum(u= uu, random_key= random_key, nu= nu) return xx, uu, ll, gg, kinetic_change return step -def build_kernel(grad_nlogp, d, integrator, transform, params): +def build_kernel(grad_logp, d, integrator, transform): - L, eps, sigma = params - hamiltonian_step, _ = integrator(T= update_position(grad_nlogp), + hamiltonian_step, _ = integrator(T= update_position(grad_logp), V= update_momentum(d), d= d) - # print("BAR") - move = update(hamiltonian_step, partially_refresh_momentum(d), d) - # print("BAZ") + move = update(hamiltonian_step, partially_refresh_momentum, d) - def kernel(rng_key : PRNGKey, state : MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: + def kernel(rng_key : PRNGKey, state : MCLMCState, params : Parameters) -> tuple[MCLMCState, MCLMCInfo]: x, u, l, g = state + L, step_size, inverse_mass_matrix = params - xx, uu, ll, gg, kinetic_change = move(x, u, g, rng_key, L, eps, sigma) + xx, uu, ll, gg, kinetic_change = move(x, u, g, rng_key, L, step_size, inverse_mass_matrix) de = kinetic_change + ll - l return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) @@ -165,15 +142,15 @@ def kernel(rng_key : PRNGKey, state : MCLMCState) -> tuple[MCLMCState, MCLMCInfo def minimal_norm(d, T, V): - def step(x, u, g, eps, sigma): + def step(x, u, g, step_size, inverse_mass_matrix): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" # V T V T V - uu, r1 = V(eps * lambda_c, u, g * sigma) - xx, ll, gg = T(eps, x, 0.5*uu*sigma) - uu, r2 = V(eps * (1 - 2 * lambda_c), uu, gg * sigma) - xx, ll, gg = T(eps, xx, 0.5*uu*sigma) - uu, r3 = V(eps * lambda_c, uu, gg * sigma) + uu, r1 = V(step_size * lambda_c, u, g * inverse_mass_matrix) + xx, ll, gg = T(step_size, x, 0.5*uu*inverse_mass_matrix) + uu, r2 = V(step_size * (1 - 2 * lambda_c), uu, gg * inverse_mass_matrix) + xx, ll, gg = T(step_size, xx, 0.5*uu*inverse_mass_matrix) + uu, r3 = V(step_size * lambda_c, uu, gg * inverse_mass_matrix) #kinetic energy change kinetic_change = (r1 + r2 + r3) * (d-1) @@ -195,24 +172,17 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, d : int, transform : Callable, - params : Parameters, - init_key, - *, integrator = minimal_norm, ) -> SamplingAlgorithm: - grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) + grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_nlogp, d, integrator, transform, params) + kernel = cls.build_kernel(grad_logp, d, integrator, transform) - def init_fn(position: ArrayLikeTree): - return cls.init(position, logdensity_fn, init_key) + def init_fn(position: ArrayLikeTree, rng_key): + return cls.init(position, logdensity_fn, rng_key) - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - ) + - return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, kernel) diff --git a/blackjax/explore.py b/explore.py similarity index 58% rename from blackjax/explore.py rename to explore.py index bb07779d4..d4ff77d70 100644 --- a/blackjax/explore.py +++ b/explore.py @@ -1,23 +1,15 @@ import jax import jax.numpy as jnp -import jax.scipy.stats as stats -import numpy as np -import sys import blackjax from blackjax.base import SamplingAlgorithm import jax import jax.numpy as jnp -import jax.scipy.stats as stats -import numpy as np -import sys import blackjax from blackjax.mcmc.mclmc import Parameters def logdensity_fn(x): return -0.5*jnp.sum(jnp.square(x-5)) -# Build the kernel -inverse_mass_matrix = jnp.array([1.0, 1.0]) # Initialize the state initial_position 
= jnp.array([1.0, 1.0]) @@ -26,22 +18,21 @@ def logdensity_fn(x): mclmc = blackjax.mcmc.mclmc.mclmc( logdensity_fn=logdensity_fn, d=2, - transform=lambda x: x, - init_key=jax.random.PRNGKey(0), - params=Parameters(0.56568545, 1.4142135, inverse_mass_matrix), + transform=lambda x: x ) +params=Parameters(L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0])) + # ? # tuning() -flip = lambda f: lambda s, k: f(k, s) - def run_sampling_algorithm( sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key ): - state = sampling_algorithm.init(initial_val) - keys = jax.random.split(rng_key, num_steps) - _, info = jax.lax.scan(flip(sampling_algorithm.step), state, keys) + + keys = jax.random.split(rng_key, num_steps+1) + state = sampling_algorithm.init(initial_val, keys[0]) + _, info = jax.lax.scan(lambda s, k: (sampling_algorithm.step(k,s, params=params)), state, keys[1:]) return info out = run_sampling_algorithm( @@ -53,3 +44,4 @@ def run_sampling_algorithm( print(jnp.mean(out.transformed_x, axis=0)) +assert(jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.004714, 5.018204])) \ No newline at end of file From 86b3a903870f830d2e6e61f0765eea6ac24dd0c2 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 18:05:56 +0100 Subject: [PATCH 03/78] wip --- blackjax/mcmc/mclmc.py | 171 ++++++++++++++++++++--------------------- explore.py | 22 +++--- 2 files changed, 95 insertions(+), 98 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 858ae91ba..8e4e5f846 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the MCLMC Kernel""" -from typing import Callable, NamedTuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp import blackjax.mcmc.integrators as integrators -import blackjax.mcmc.metrics as metrics -import blackjax.mcmc.proposal as proposal -import blackjax.mcmc.trajectory as trajectory from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc", "Parameters"] + class Parameters(NamedTuple): - """Tunable parameters - """ + """Tunable parameters""" L: float step_size: float @@ -49,116 +46,116 @@ class MCLMCInfo(NamedTuple): l: Array de: float -def init(x_initial : ArrayLikeTree, logdensity_fn, random_key): - grad_logp = jax.value_and_grad(logdensity_fn) - l, g = grad_logp(x_initial) +def init(x_initial: ArrayTree, logdensity_fn, random_key): + grad_logp = jax.value_and_grad(logdensity_fn) + l, g = grad_logp(x_initial) - u = random_unit_vector(random_key, d=x_initial.shape[0]) + u = random_unit_vector(random_key, d=x_initial.shape[0]) - return MCLMCState(x_initial, u, l, g) + return MCLMCState(x_initial, u, l, g) -def random_unit_vector(random_key,d): - u = jax.random.normal(jax.random.PRNGKey(0), shape = (d, )) +def random_unit_vector(random_key, d): + u = jax.random.normal(random_key, shape=(d,)) u /= jnp.sqrt(jnp.sum(jnp.square(u))) return u def update_position(grad_logp): - - def update(step_size, x, u): - xx = x + step_size * u - ll, gg = grad_logp(xx) - return xx, ll, gg - - return update - -def update_momentum(d): - """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) - similar to the implementation: https://github.com/gregversteeg/esh_dynamics - There are no exponentials 
e^delta, which prevents overflows when the gradient norm is large.""" - - - def update(step_size, u, g): - g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) - e = g / g_norm - ue = jnp.dot(u, e) - delta = step_size * g_norm / (d-1) - zeta = jnp.exp(-delta) - uu = e *(1-zeta)*(1+zeta + ue * (1-zeta)) + 2*zeta* u - delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1-ue)*zeta**2) - return uu/jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r - - return update - - + def update(step_size, x, u): + xx = x + step_size * u + ll, gg = grad_logp(xx) + return xx, ll, gg + + return update + + + +def update_momentum(step_size, u, g): + """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) + similar to the implementation: https://github.com/gregversteeg/esh_dynamics + There are no exponentials e^delta, which prevents overflows when the gradient norm is large. + """ + g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) + e = g / g_norm + ue = jnp.dot(u, e) + dim = u.shape[0] + delta = step_size * g_norm / (dim - 1) + zeta = jnp.exp(-delta) + uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u + delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2) + return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r + + + def partially_refresh_momentum(u, random_key, nu): """Adds a small noise to u and normalizes.""" - z = nu * jax.random.normal(random_key, shape = (u.shape[0], )) + z = nu * jax.random.normal(random_key, shape=(u.shape[0],)) return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - - def update(hamiltonian_dynamics, partially_refresh_momentum, d): - - def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): - """One step of the generalized dynamics.""" + def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): + """One step of the generalized dynamics.""" + + # Hamiltonian step + xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics( + x=x, u=u, g=g, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix + ) - # Hamiltonian step - xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, step_size=step_size, inverse_mass_matrix = inverse_mass_matrix) + # Langevin-like noise + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / d) + uu = partially_refresh_momentum(u=uu, random_key=random_key, nu=nu) - # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.) 
/ d) - uu = partially_refresh_momentum(u= uu, random_key= random_key, nu= nu) + return xx, uu, ll, gg, kinetic_change - return xx, uu, ll, gg, kinetic_change + return step - return step def build_kernel(grad_logp, d, integrator, transform): + hamiltonian_step = integrator( + T=update_position(grad_logp), V=update_momentum, d=d + ) + move = update(hamiltonian_step, partially_refresh_momentum, d) + def kernel( + rng_key: PRNGKey, state: MCLMCState, params: Parameters + ) -> tuple[MCLMCState, MCLMCInfo]: + x, u, l, g = state - hamiltonian_step, _ = integrator(T= update_position(grad_logp), - V= update_momentum(d), - d= d) - move = update(hamiltonian_step, partially_refresh_momentum, d) - - def kernel(rng_key : PRNGKey, state : MCLMCState, params : Parameters) -> tuple[MCLMCState, MCLMCInfo]: + L, step_size, inverse_mass_matrix = params - x, u, l, g = state + xx, uu, ll, gg, kinetic_change = move( + x, u, g, rng_key, L, step_size, inverse_mass_matrix + ) + de = kinetic_change + ll - l + return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) - L, step_size, inverse_mass_matrix = params - - xx, uu, ll, gg, kinetic_change = move(x, u, g, rng_key, L, step_size, inverse_mass_matrix) - de = kinetic_change + ll - l - return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) + return kernel - return kernel -lambda_c = 0.1931833275037836 #critical value of the lambda parameter for the minimal norm integrator +lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator -def minimal_norm(d, T, V): - def step(x, u, g, step_size, inverse_mass_matrix): - """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" +def minimal_norm(d, T, V): + def step(x, u, g, step_size, inverse_mass_matrix): + """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" - # V T V T V - uu, r1 = V(step_size * lambda_c, u, g * inverse_mass_matrix) - xx, ll, gg = T(step_size, x, 0.5*uu*inverse_mass_matrix) - uu, r2 = V(step_size * (1 - 2 * lambda_c), uu, gg * inverse_mass_matrix) - xx, ll, gg = T(step_size, xx, 0.5*uu*inverse_mass_matrix) - uu, r3 = V(step_size * lambda_c, uu, gg * inverse_mass_matrix) + # V T V T V + uu, r1 = V(step_size * lambda_c, u, g * inverse_mass_matrix) + xx, ll, gg = T(step_size, x, 0.5 * uu * inverse_mass_matrix) + uu, r2 = V(step_size * (1 - 2 * lambda_c), uu, gg * inverse_mass_matrix) + xx, ll, gg = T(step_size, xx, 0.5 * uu * inverse_mass_matrix) + uu, r3 = V(step_size * lambda_c, uu, gg * inverse_mass_matrix) - #kinetic energy change - kinetic_change = (r1 + r2 + r3) * (d-1) + # kinetic energy change + kinetic_change = (r1 + r2 + r3) * (d - 1) - return xx, uu, ll, gg, kinetic_change - - return step, 2 + return xx, uu, ll, gg, kinetic_change + return step class mclmc: @@ -170,19 +167,15 @@ class mclmc: def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, - d : int, - transform : Callable, - integrator = minimal_norm, + d: int, + transform: Callable, + integrator=minimal_norm, ) -> SamplingAlgorithm: - grad_logp = jax.value_and_grad(logdensity_fn) kernel = cls.build_kernel(grad_logp, d, integrator, transform) - def init_fn(position: ArrayLikeTree, rng_key): + def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): return cls.init(position, logdensity_fn, rng_key) - - return SamplingAlgorithm(init_fn, kernel) - diff --git a/explore.py b/explore.py index d4ff77d70..6176d8c8d 100644 --- a/explore.py +++ b/explore.py @@ -7,8 +7,9 @@ import blackjax from blackjax.mcmc.mclmc import 
Parameters + def logdensity_fn(x): - return -0.5*jnp.sum(jnp.square(x-5)) + return -0.5 * jnp.sum(jnp.square(x - 5)) # Initialize the state @@ -16,25 +17,28 @@ def logdensity_fn(x): mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - d=2, - transform=lambda x: x + logdensity_fn=logdensity_fn, d=2, transform=lambda x: x ) -params=Parameters(L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0])) +params = Parameters( + L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0]) +) # ? # tuning() + def run_sampling_algorithm( sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key ): - - keys = jax.random.split(rng_key, num_steps+1) + keys = jax.random.split(rng_key, num_steps + 1) state = sampling_algorithm.init(initial_val, keys[0]) - _, info = jax.lax.scan(lambda s, k: (sampling_algorithm.step(k,s, params=params)), state, keys[1:]) + _, info = jax.lax.scan( + lambda s, k: (sampling_algorithm.step(k, s, params=params)), state, keys[1:] + ) return info + out = run_sampling_algorithm( sampling_algorithm=mclmc, num_steps=10000, @@ -44,4 +48,4 @@ def run_sampling_algorithm( print(jnp.mean(out.transformed_x, axis=0)) -assert(jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.004714, 5.018204])) \ No newline at end of file +assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.0048037, 5.0181437]) From e82550f3db0fb9e3cbc8a7beb803410ecd939a89 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 18:07:50 +0100 Subject: [PATCH 04/78] wip --- blackjax/mcmc/mclmc.py | 24 ++++++++++++------------ explore.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 8e4e5f846..fd99c4c68 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -51,13 +51,13 @@ def init(x_initial: ArrayTree, logdensity_fn, random_key): grad_logp = jax.value_and_grad(logdensity_fn) l, g = grad_logp(x_initial) - u = random_unit_vector(random_key, d=x_initial.shape[0]) + u = random_unit_vector(random_key, dim=x_initial.shape[0]) return MCLMCState(x_initial, u, l, g) -def random_unit_vector(random_key, d): - u = jax.random.normal(random_key, shape=(d,)) +def random_unit_vector(random_key, dim): + u = jax.random.normal(random_key, shape=(dim,)) u /= jnp.sqrt(jnp.sum(jnp.square(u))) return u @@ -96,7 +96,7 @@ def partially_refresh_momentum(u, random_key, nu): return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) -def update(hamiltonian_dynamics, partially_refresh_momentum, d): +def update(hamiltonian_dynamics, partially_refresh_momentum, dim): def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): """One step of the generalized dynamics.""" @@ -106,7 +106,7 @@ def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): ) # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / d) + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) uu = partially_refresh_momentum(u=uu, random_key=random_key, nu=nu) return xx, uu, ll, gg, kinetic_change @@ -114,11 +114,11 @@ def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): return step -def build_kernel(grad_logp, d, integrator, transform): +def build_kernel(grad_logp, dim, integrator, transform): hamiltonian_step = integrator( - T=update_position(grad_logp), V=update_momentum, d=d + T=update_position(grad_logp), V=update_momentum, dim=dim ) - move = update(hamiltonian_step, partially_refresh_momentum, d) + move = update(hamiltonian_step, partially_refresh_momentum, dim) 
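+    # Descriptive note (in terms of the names defined in this file): `move` composes
+    # one minimal-norm integrator step (the V T V T V scheme below) with a partial
+    # momentum refresh whose noise scale nu = sqrt((exp(2 * step_size / L) - 1) / dim)
+    # is chosen so that, roughly, the momentum direction decorrelates over a
+    # trajectory length L.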
def kernel( rng_key: PRNGKey, state: MCLMCState, params: Parameters @@ -139,7 +139,7 @@ def kernel( lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator -def minimal_norm(d, T, V): +def minimal_norm(dim, T, V): def step(x, u, g, step_size, inverse_mass_matrix): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" @@ -151,7 +151,7 @@ def step(x, u, g, step_size, inverse_mass_matrix): uu, r3 = V(step_size * lambda_c, uu, gg * inverse_mass_matrix) # kinetic energy change - kinetic_change = (r1 + r2 + r3) * (d - 1) + kinetic_change = (r1 + r2 + r3) * (dim - 1) return xx, uu, ll, gg, kinetic_change @@ -167,13 +167,13 @@ class mclmc: def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, - d: int, + dim: int, transform: Callable, integrator=minimal_norm, ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, d, integrator, transform) + kernel = cls.build_kernel(grad_logp, dim, integrator, transform) def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): return cls.init(position, logdensity_fn, rng_key) diff --git a/explore.py b/explore.py index 6176d8c8d..b7be6addb 100644 --- a/explore.py +++ b/explore.py @@ -17,7 +17,7 @@ def logdensity_fn(x): mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, d=2, transform=lambda x: x + logdensity_fn=logdensity_fn, dim=2, transform=lambda x: x ) params = Parameters( From f0e1bec9042b856b347c11f0a17ba979087df0a6 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 18:13:23 +0100 Subject: [PATCH 05/78] wip --- blackjax/mcmc/mclmc.py | 61 +++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index fd99c4c68..63004355d 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -31,10 +31,8 @@ class Parameters(NamedTuple): step_size: float inverse_mass_matrix: Array - MCLMCState = integrators.IntegratorState - class MCLMCInfo(NamedTuple): """Additional information on the MCLMC transition. 
@@ -47,14 +45,9 @@ class MCLMCInfo(NamedTuple): de: float -def init(x_initial: ArrayTree, logdensity_fn, random_key): - grad_logp = jax.value_and_grad(logdensity_fn) - l, g = grad_logp(x_initial) - - u = random_unit_vector(random_key, dim=x_initial.shape[0]) - - return MCLMCState(x_initial, u, l, g) - +### +# helper funcs +### def random_unit_vector(random_key, dim): u = jax.random.normal(random_key, shape=(dim,)) @@ -70,7 +63,16 @@ def update(step_size, x, u): return update +def partially_refresh_momentum(u, random_key, nu): + """Adds a small noise to u and normalizes.""" + z = nu * jax.random.normal(random_key, shape=(u.shape[0],)) + + return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) + +### +# integrator +### def update_momentum(step_size, u, g): """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) @@ -88,37 +90,19 @@ def update_momentum(step_size, u, g): return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r +def init(x_initial: ArrayTree, logdensity_fn, random_key): + grad_logp = jax.value_and_grad(logdensity_fn) + l, g = grad_logp(x_initial) -def partially_refresh_momentum(u, random_key, nu): - """Adds a small noise to u and normalizes.""" - z = nu * jax.random.normal(random_key, shape=(u.shape[0],)) - - return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - - -def update(hamiltonian_dynamics, partially_refresh_momentum, dim): - def step(x, u, g, random_key, L, step_size, inverse_mass_matrix): - """One step of the generalized dynamics.""" - - # Hamiltonian step - xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics( - x=x, u=u, g=g, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix - ) - - # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) - uu = partially_refresh_momentum(u=uu, random_key=random_key, nu=nu) - - return xx, uu, ll, gg, kinetic_change + u = random_unit_vector(random_key, dim=x_initial.shape[0]) - return step + return MCLMCState(x_initial, u, l, g) def build_kernel(grad_logp, dim, integrator, transform): - hamiltonian_step = integrator( + step = integrator( T=update_position(grad_logp), V=update_momentum, dim=dim ) - move = update(hamiltonian_step, partially_refresh_momentum, dim) def kernel( rng_key: PRNGKey, state: MCLMCState, params: Parameters @@ -127,9 +111,14 @@ def kernel( L, step_size, inverse_mass_matrix = params - xx, uu, ll, gg, kinetic_change = move( - x, u, g, rng_key, L, step_size, inverse_mass_matrix + xx, uu, ll, gg, kinetic_change = step( + x=x, u=u, g=g, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix ) + + # Langevin-like noise + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + uu = partially_refresh_momentum(u=uu, random_key=rng_key, nu=nu) + de = kinetic_change + ll - l return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) From 4d7dc572c0bece87f135ea475d9560cdaec76a17 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 18:34:49 +0100 Subject: [PATCH 06/78] wip --- blackjax/diagnostics.py | 2 +- blackjax/mcmc/__init__.py | 4 +- blackjax/mcmc/mclmc.py | 78 +++++++++++++++++++-------------------- explore.py | 56 +++++++++++++++++++++++----- 4 files changed, 86 insertions(+), 54 deletions(-) diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index da861d9b1..ed7ad5bbc 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -206,4 +206,4 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): tau_hat = jnp.maximum(tau_hat, 1 / np.log10(ess_raw)) ess = ess_raw / tau_hat - return ess.squeeze() + 
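The Langevin-like noise applied after each Hamiltonian step is a partial momentum refresh: Gaussian noise of scale nu = sqrt((exp(2 * step_size / L) - 1) / dim) is added to u and the result is renormalized, so the momentum stays on the unit sphere while decorrelating over a distance of roughly L. A standalone sketch of that update, with a quick norm check:

import jax
import jax.numpy as jnp

def partially_refresh_momentum(u, rng_key, nu):
    # add isotropic Gaussian noise of scale nu, then project back onto the unit sphere
    z = nu * jax.random.normal(rng_key, shape=u.shape)
    return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z)))

dim, step_size, L = 2, 0.1, 1.0
nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)  # noise scale used in the kernel

u = jnp.array([1.0, 0.0])                                 # a unit momentum vector
u_new = partially_refresh_momentum(u, jax.random.PRNGKey(0), nu)
print(jnp.sum(jnp.square(u_new)))                         # ~1.0: still a unit vector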
return ess.squeeze() \ No newline at end of file diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 59975f5b1..ce9a149b3 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -4,10 +4,10 @@ hmc, mala, marginal_latent_gaussian, + mclmc, nuts, periodic_orbital, random_walk, - mclmc ) __all__ = [ @@ -19,5 +19,5 @@ "periodic_orbital", "marginal_latent_gaussian", "random_walk", - "mclmc" + "mclmc", ] diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 63004355d..ac8001049 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -19,7 +19,7 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import Array, ArrayLike, PRNGKey __all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc", "Parameters"] @@ -31,24 +31,23 @@ class Parameters(NamedTuple): step_size: float inverse_mass_matrix: Array + MCLMCState = integrators.IntegratorState -class MCLMCInfo(NamedTuple): - """Additional information on the MCLMC transition. - This additional information can be used for debugging or computing - diagnostics. - """ +class MCLMCInfo(NamedTuple): + """Additional information on the MCLMC transition.""" transformed_x: Array - l: Array - de: float + logdensity: Array + dE: float ### # helper funcs ### + def random_unit_vector(random_key, dim): u = jax.random.normal(random_key, shape=(dim,)) u /= jnp.sqrt(jnp.sum(jnp.square(u))) @@ -63,10 +62,10 @@ def update(step_size, x, u): return update + def partially_refresh_momentum(u, random_key, nu): """Adds a small noise to u and normalizes.""" z = nu * jax.random.normal(random_key, shape=(u.shape[0],)) - return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) @@ -74,6 +73,7 @@ def partially_refresh_momentum(u, random_key, nu): # integrator ### + def update_momentum(step_size, u, g): """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) similar to the implementation: https://github.com/gregversteeg/esh_dynamics @@ -90,37 +90,30 @@ def update_momentum(step_size, u, g): return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r -def init(x_initial: ArrayTree, logdensity_fn, random_key): - grad_logp = jax.value_and_grad(logdensity_fn) - l, g = grad_logp(x_initial) - - u = random_unit_vector(random_key, dim=x_initial.shape[0]) - - return MCLMCState(x_initial, u, l, g) - - -def build_kernel(grad_logp, dim, integrator, transform): - step = integrator( - T=update_position(grad_logp), V=update_momentum, dim=dim +def init(x_initial: ArrayLike, logdensity_fn, random_key): + l, g = jax.value_and_grad(logdensity_fn)(x_initial) + return MCLMCState( + position=x_initial, + momentum=random_unit_vector(random_key, dim=x_initial.shape[0]), + logdensity=l, + logdensity_grad=g, ) - def kernel( - rng_key: PRNGKey, state: MCLMCState, params: Parameters - ) -> tuple[MCLMCState, MCLMCInfo]: - x, u, l, g = state - L, step_size, inverse_mass_matrix = params - - xx, uu, ll, gg, kinetic_change = step( - x=x, u=u, g=g, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix - ) +def build_kernel(grad_logp, dim: int, integrator, transform, params: Parameters): + step = integrator(T=update_position(grad_logp), V=update_momentum, dim=dim) + def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: + xx, uu, ll, gg, kinetic_change = step(state, params) # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + 
nu = jnp.sqrt((jnp.exp(2 * params.step_size / params.L) - 1.0) / dim) uu = partially_refresh_momentum(u=uu, random_key=rng_key, nu=nu) - de = kinetic_change + ll - l - return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) + return MCLMCState(xx, uu, ll, gg), MCLMCInfo( + transformed_x=transform(xx), + logdensity=ll, + dE=kinetic_change + ll - state.logdensity, + ) return kernel @@ -129,15 +122,17 @@ def kernel( def minimal_norm(dim, T, V): - def step(x, u, g, step_size, inverse_mass_matrix): + def step(state: MCLMCState, params: Parameters): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" # V T V T V - uu, r1 = V(step_size * lambda_c, u, g * inverse_mass_matrix) - xx, ll, gg = T(step_size, x, 0.5 * uu * inverse_mass_matrix) - uu, r2 = V(step_size * (1 - 2 * lambda_c), uu, gg * inverse_mass_matrix) - xx, ll, gg = T(step_size, xx, 0.5 * uu * inverse_mass_matrix) - uu, r3 = V(step_size * lambda_c, uu, gg * inverse_mass_matrix) + dt = params.step_size + sigma = params.inverse_mass_matrix + uu, r1 = V(dt * lambda_c, state.momentum, state.logdensity_grad * sigma) + xx, ll, gg = T(dt, state.position, 0.5 * uu * sigma) + uu, r2 = V(dt * (1 - 2 * lambda_c), uu, gg * sigma) + xx, ll, gg = T(dt, xx, 0.5 * uu * sigma) + uu, r3 = V(dt * lambda_c, uu, gg * sigma) # kinetic energy change kinetic_change = (r1 + r2 + r3) * (dim - 1) @@ -158,13 +153,14 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, dim: int, transform: Callable, + params : Parameters, integrator=minimal_norm, ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, dim, integrator, transform) + kernel = cls.build_kernel(grad_logp, dim, integrator, transform, params) - def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): + def init_fn(position: ArrayLike, rng_key: PRNGKey): return cls.init(position, logdensity_fn, rng_key) return SamplingAlgorithm(init_fn, kernel) diff --git a/explore.py b/explore.py index b7be6addb..1bb1230dd 100644 --- a/explore.py +++ b/explore.py @@ -1,10 +1,9 @@ import jax import jax.numpy as jnp + import blackjax from blackjax.base import SamplingAlgorithm -import jax -import jax.numpy as jnp -import blackjax +from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.mclmc import Parameters @@ -16,16 +15,19 @@ def logdensity_fn(x): initial_position = jnp.array([1.0, 1.0]) +dim = 2 + mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, dim=2, transform=lambda x: x + logdensity_fn=logdensity_fn, + dim=dim, + transform=lambda x: x, + params=Parameters( + L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0]) + ), ) -params = Parameters( - L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0]) -) -# ? -# tuning() + def run_sampling_algorithm( @@ -34,11 +36,42 @@ def run_sampling_algorithm( keys = jax.random.split(rng_key, num_steps + 1) state = sampling_algorithm.init(initial_val, keys[0]) _, info = jax.lax.scan( - lambda s, k: (sampling_algorithm.step(k, s, params=params)), state, keys[1:] + lambda s, k: (sampling_algorithm.step(k, s)), state, keys[1:] ) return info +# ? 
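Parameters is a plain NamedTuple, so the usual tuple conveniences apply: the kernel unpacks it positionally (L, step_size, inverse_mass_matrix = params), and variants can be built with _replace. A small illustration, assuming the import path used at this point in the series:

import jax.numpy as jnp
from blackjax.mcmc.mclmc import Parameters  # as defined at this point in the series

params = Parameters(
    L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0])
)
L, step_size, inverse_mass_matrix = params              # positional unpacking, as in the kernel
half_step = params._replace(step_size=params.step_size / 2)  # copy with one field changed
print(half_step.step_size, half_step.L)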
+# tuning() +num_steps = 10000 +initial_params = Parameters(L=jnp.sqrt(dim),step_size=0.4*jnp.sqrt(dim), inverse_mass_matrix=jnp.array([1.0, 1.0])) +mclmc = blackjax.mcmc.mclmc.mclmc( + logdensity_fn=logdensity_fn, + dim=dim, + transform=lambda x: x, + params=initial_params +) +out = run_sampling_algorithm( + sampling_algorithm=mclmc, + num_steps= int(num_steps * 0.1), + initial_val=jnp.array([0.1, 0.1]), + rng_key=jax.random.PRNGKey(0), +) +Lfactor = 0.4 +ESS = effective_sample_size(out.transformed_x) +Lnew = Lfactor * initial_params.step_size / ESS # = 0.4 * correlation length +print(Lnew) +raise Exception + +# def tune3(self, x, u, l, g, random_key, L, eps, sigma, num_steps): +# """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" +# X, xx, uu, ll, gg, key = self.sample_full(num_steps, x, u, l, g, random_key, L, eps, sigma) +# ESS = ess_corr(X) +# Lnew = self.Lfactor * eps / ESS # = 0.4 * correlation length + +# return Lnew, xx, uu, ll, gg, key + + out = run_sampling_algorithm( sampling_algorithm=mclmc, num_steps=10000, @@ -49,3 +82,6 @@ def run_sampling_algorithm( print(jnp.mean(out.transformed_x, axis=0)) assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.0048037, 5.0181437]) + + + From 82b8466abe8b7058bd78c100425adfcfe623cb4b Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 23:02:14 +0100 Subject: [PATCH 07/78] wip --- blackjax/mcmc/mclmc.py | 117 +++++++++++------------ explore.py | 212 ++++++++++++++++++++++++++++++++++------- 2 files changed, 238 insertions(+), 91 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ac8001049..1bb95e8af 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -23,15 +23,6 @@ __all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc", "Parameters"] - -class Parameters(NamedTuple): - """Tunable parameters""" - - L: float - step_size: float - inverse_mass_matrix: Array - - MCLMCState = integrators.IntegratorState @@ -42,59 +33,19 @@ class MCLMCInfo(NamedTuple): logdensity: Array dE: float +class Parameters(NamedTuple): + """Tunable parameters""" -### -# helper funcs -### - - -def random_unit_vector(random_key, dim): - u = jax.random.normal(random_key, shape=(dim,)) - u /= jnp.sqrt(jnp.sum(jnp.square(u))) - return u - - -def update_position(grad_logp): - def update(step_size, x, u): - xx = x + step_size * u - ll, gg = grad_logp(xx) - return xx, ll, gg - - return update - - -def partially_refresh_momentum(u, random_key, nu): - """Adds a small noise to u and normalizes.""" - z = nu * jax.random.normal(random_key, shape=(u.shape[0],)) - return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - - -### -# integrator -### - - -def update_momentum(step_size, u, g): - """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) - similar to the implementation: https://github.com/gregversteeg/esh_dynamics - There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
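The tuning draft above sets the momentum decoherence length L to roughly 0.4 times the chain's correlation length, estimated from an effective-sample-size diagnostic on a short warm-up run. A hedged sketch of that heuristic, assuming blackjax.diagnostics.effective_sample_size with its (chain, sample, dimension) layout; the warm-up chain is split in half to give the diagnostic a chain axis, and dimensions are combined with a harmonic mean as in the ess_corr helper introduced later:

import jax.numpy as jnp
from blackjax.diagnostics import effective_sample_size

def tune_L(samples, step_size, Lfactor=0.4):
    # samples: (num_samples, dim) from a short warm-up run
    half = samples.shape[0] // 2
    chains = jnp.stack([samples[:half], samples[half:2 * half]])  # (2, half, dim)
    ess = effective_sample_size(chains)        # effective sample count, per dimension
    neff = ess / (2 * half)                    # per-sample efficiency, in (0, 1]
    neff = 1.0 / jnp.mean(1.0 / neff)          # combine dimensions with a harmonic mean
    return Lfactor * step_size / neff          # = Lfactor * correlation length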
- """ - g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) - e = g / g_norm - ue = jnp.dot(u, e) - dim = u.shape[0] - delta = step_size * g_norm / (dim - 1) - zeta = jnp.exp(-delta) - uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u - delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2) - return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r + L: float + step_size: float + inverse_mass_matrix: Array -def init(x_initial: ArrayLike, logdensity_fn, random_key): +def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) return MCLMCState( position=x_initial, - momentum=random_unit_vector(random_key, dim=x_initial.shape[0]), + momentum=random_unit_vector(rng_key, dim=x_initial.shape[0]), logdensity=l, logdensity_grad=g, ) @@ -107,7 +58,7 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, params) # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * params.step_size / params.L) - 1.0) / dim) - uu = partially_refresh_momentum(u=uu, random_key=rng_key, nu=nu) + uu = partially_refresh_momentum(u=uu, rng_key=rng_key, nu=nu) return MCLMCState(xx, uu, ll, gg), MCLMCInfo( transformed_x=transform(xx), @@ -118,16 +69,16 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: return kernel -lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator def minimal_norm(dim, T, V): + lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator def step(state: MCLMCState, params: Parameters): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" # V T V T V dt = params.step_size - sigma = params.inverse_mass_matrix + sigma = jnp.sqrt(params.inverse_mass_matrix) uu, r1 = V(dt * lambda_c, state.momentum, state.logdensity_grad * sigma) xx, ll, gg = T(dt, state.position, 0.5 * uu * sigma) uu, r2 = V(dt * (1 - 2 * lambda_c), uu, gg * sigma) @@ -164,3 +115,51 @@ def init_fn(position: ArrayLike, rng_key: PRNGKey): return cls.init(position, logdensity_fn, rng_key) return SamplingAlgorithm(init_fn, kernel) + + +### +# helper funcs +### + + +def random_unit_vector(rng_key, dim): + u = jax.random.normal(rng_key, shape=(dim,)) + u /= jnp.sqrt(jnp.sum(jnp.square(u))) + return u + + +def update_position(grad_logp): + def update(step_size, x, u): + xx = x + step_size * u + ll, gg = grad_logp(xx) + return xx, ll, gg + + return update + + +def partially_refresh_momentum(u, rng_key, nu): + """Adds a small noise to u and normalizes.""" + z = nu * jax.random.normal(rng_key, shape=(u.shape[0],)) + return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) + + +### +# integrator +### + + +def update_momentum(step_size, u, g): + """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) + similar to the implementation: https://github.com/gregversteeg/esh_dynamics + There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
+ """ + g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) + e = g / g_norm + ue = jnp.dot(u, e) + dim = u.shape[0] + delta = step_size * g_norm / (dim - 1) + zeta = jnp.exp(-delta) + uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u + delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2) + return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r + diff --git a/explore.py b/explore.py index 1bb1230dd..88aab2faf 100644 --- a/explore.py +++ b/explore.py @@ -5,6 +5,8 @@ from blackjax.base import SamplingAlgorithm from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.mclmc import Parameters +from blackjax.types import PRNGKey +from scipy.fftpack import next_fast_len #type: ignore def logdensity_fn(x): @@ -17,14 +19,7 @@ def logdensity_fn(x): dim = 2 -mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - dim=dim, - transform=lambda x: x, - params=Parameters( - L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0]) - ), -) + @@ -41,47 +36,200 @@ def run_sampling_algorithm( return info +key = jax.random.PRNGKey(0) +main_key, tune_key = jax.random.split(key) + +def ess_corr(x): + """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html + shape(x) = (num_samples, d)""" + + input_array = jnp.array([x, ]) + + num_chains = 1#input_array.shape[0] + num_samples = input_array.shape[1] + + mean_across_chain = input_array.mean(axis=1, keepdims=True) + # Compute autocovariance estimates for every lag for the input array using FFT. + centered_array = input_array - mean_across_chain + m = next_fast_len(2 * num_samples) + ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) + ifft_ary *= jnp.conjugate(ifft_ary) + autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) + autocov_value = ( + jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples + ) + mean_autocov_var = autocov_value.mean(0, keepdims=True) + mean_var0 = (jnp.take(mean_autocov_var, jnp.array([0]), axis=1) * num_samples / (num_samples - 1.0)) + weighted_var = mean_var0 * (num_samples - 1.0) / num_samples + weighted_var = jax.lax.cond( + num_chains > 1, + lambda _: weighted_var+ mean_across_chain.var(axis=0, ddof=1, keepdims=True), + lambda _: weighted_var, + operand=None, + ) + + # Geyer's initial positive sequence + num_samples_even = num_samples - num_samples % 2 + mean_autocov_var_tp1 = jnp.take(mean_autocov_var, jnp.arange(1, num_samples_even), axis=1) + rho_hat = jnp.concatenate([jnp.ones_like(mean_var0), 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var,], axis=1,) + + rho_hat = jnp.moveaxis(rho_hat, 1, 0) + rho_hat_even = rho_hat[0::2] + rho_hat_odd = rho_hat[1::2] + + mask0 = (rho_hat_even + rho_hat_odd) > 0.0 + carry_cond = jnp.ones_like(mask0[0]) + max_t = jnp.zeros_like(mask0[0], dtype=int) + + def positive_sequence_body_fn(state, mask_t): + t, carry_cond, max_t = state + next_mask = carry_cond & mask_t + next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) + return (t + 1, next_mask, next_max_t), next_mask + + (*_, max_t_next), mask = jax.lax.scan( + positive_sequence_body_fn, (0, carry_cond, max_t), mask0 + ) + indices = jnp.indices(max_t_next.shape) + indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) + rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) + # improve estimation + mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) + rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) + + # Geyer's initial monotone sequence + def 
monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): + update_mask = rho_hat_sum_t > rho_hat_sum_tm1 + next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) + return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) + + rho_hat_sum = rho_hat_even + rho_hat_odd + _, (update_mask, update_value) = jax.lax.scan( + monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum + ) + + rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) + rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) + + # compute effective sample size + ess_raw = num_chains * num_samples + tau_hat = (-1.0 + + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) + - rho_hat_even_final[indices] + ) + + tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) + ess = ess_raw / tau_hat + + ### my part (combine all dimensions): ### + neff = ess.squeeze() / num_samples + return 1.0 / jnp.average(1 / neff) + + # ? # tuning() num_steps = 10000 -initial_params = Parameters(L=jnp.sqrt(dim),step_size=0.4*jnp.sqrt(dim), inverse_mass_matrix=jnp.array([1.0, 1.0])) +initial_params = Parameters(1.3852125, 1.0604926, jnp.array([1., 1.])) +# Parameters(L=jnp.sqrt(dim),step_size=0.4*jnp.sqrt(dim), inverse_mass_matrix=jnp.array([1.0, 1.0])) +def tune(num_steps : int, params : Parameters, rng_key : PRNGKey) -> Parameters: + + # steps1 = (int)(num_steps * 0.1) + # steps2 = (int)(num_steps * 0.1) + # def tune12(self, x, u, l, g, random_key, L_given, eps, sigma_given, num_steps1, num_steps2): + # """cheap hyperparameter tuning""" + + + # def step(state, outer_weight): + # """one adaptive step of the dynamics""" + # x, u, l, g, E, Feps, Weps, eps_max, key, eps = self.dynamics_adaptive(state[0], L, sigma) + # W, F1, F2 = state[1] + # w = outer_weight * eps + # zero_prevention = 1-outer_weight + # F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter + # F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter + # W += w + + # return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps + + # L = L_given + + # # we use the last num_steps2 to compute the diagonal preconditioner + # outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + # #initial state + # state = ((x, u, l, g, 0., jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), (0., jnp.zeros(len(x)), jnp.zeros(len(x)))) + # # run the steps + # state, eps = jax.lax.scan(step, init=state, xs= outer_weights, length= num_steps1 + num_steps2) + # # determine L + # if num_steps2 != 0.: + # F1, F2 = state[1][1], state[1][2] + # variances = F2 - jnp.square(F1) + # sigma2 = jnp.average(variances) + + # # optionally we do the diagonal preconditioning (and readjust the stepsize) + # if self.diagonal_preconditioning: + + # # diagonal preconditioning + # sigma = jnp.sqrt(variances) + # L = jnp.sqrt(self.Target.d) + + # #readjust the stepsize + # steps = num_steps2 // 3 #we do some small number of steps + # state, eps = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps) + # else: + # L = jnp.sqrt(sigma2 * self.Target.d) + + # xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state + # return L, eps[-1], sigma, xx, uu, ll, gg, key #return the tuned hyperparameters and the final state + + + # params = + + mclmc = blackjax.mcmc.mclmc.mclmc( + logdensity_fn=logdensity_fn, + dim=dim, + transform=lambda x: x, + params=params + ) + out = run_sampling_algorithm( + 
sampling_algorithm=mclmc, + num_steps= int(num_steps * 0.1), + initial_val=jnp.array([0.1, 0.1]), + rng_key=rng_key, + ) + Lfactor = 0.4 + # ESS = effective_sample_size(out.transformed_x) + ESS = ess_corr(out.transformed_x) + # neff = ESS / num_steps + # ESS = 1.0 / jnp.average(1 / neff) + # print(f"Ess is {ESS}") + # print(f"Ess is {ESS2}") + # print(out.transformed_x) + Lnew = Lfactor * initial_params.step_size / ESS + return Parameters(L=Lnew, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix) + +print(tune(num_steps=10000, params=initial_params, rng_key=tune_key)) + + mclmc = blackjax.mcmc.mclmc.mclmc( logdensity_fn=logdensity_fn, dim=dim, transform=lambda x: x, - params=initial_params -) -out = run_sampling_algorithm( - sampling_algorithm=mclmc, - num_steps= int(num_steps * 0.1), - initial_val=jnp.array([0.1, 0.1]), - rng_key=jax.random.PRNGKey(0), + params=Parameters( + L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0]) + ), ) -Lfactor = 0.4 -ESS = effective_sample_size(out.transformed_x) -Lnew = Lfactor * initial_params.step_size / ESS # = 0.4 * correlation length -print(Lnew) -raise Exception - -# def tune3(self, x, u, l, g, random_key, L, eps, sigma, num_steps): -# """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" -# X, xx, uu, ll, gg, key = self.sample_full(num_steps, x, u, l, g, random_key, L, eps, sigma) -# ESS = ess_corr(X) -# Lnew = self.Lfactor * eps / ESS # = 0.4 * correlation length - -# return Lnew, xx, uu, ll, gg, key - out = run_sampling_algorithm( sampling_algorithm=mclmc, num_steps=10000, initial_val=jnp.array([0.1, 0.1]), - rng_key=jax.random.PRNGKey(0), + rng_key=main_key, ) print(jnp.mean(out.transformed_x, axis=0)) -assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.0048037, 5.0181437]) +assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.0377626, 4.9752364]) From a4d403baad746571d7fe6a3da3743cc4548ac874 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 23:18:31 +0100 Subject: [PATCH 08/78] fix pre-commit --- blackjax/diagnostics.py | 2 +- blackjax/mcmc/mclmc.py | 11 +- explore.py | 235 ---------------------------------------- 3 files changed, 6 insertions(+), 242 deletions(-) delete mode 100644 explore.py diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index ed7ad5bbc..da861d9b1 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -206,4 +206,4 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): tau_hat = jnp.maximum(tau_hat, 1 / np.log10(ess_raw)) ess = ess_raw / tau_hat - return ess.squeeze() \ No newline at end of file + return ess.squeeze() diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 1bb95e8af..8ed80d6d2 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -33,6 +33,7 @@ class MCLMCInfo(NamedTuple): logdensity: Array dE: float + class Parameters(NamedTuple): """Tunable parameters""" @@ -69,10 +70,9 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: return kernel - - def minimal_norm(dim, T, V): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator + def step(state: MCLMCState, params: Parameters): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" @@ -104,15 +104,15 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, dim: int, transform: Callable, - params : Parameters, + params: Parameters, integrator=minimal_norm, ) -> 
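Since jax.lax.scan stacks MCLMCInfo over steps, posterior summaries and energy diagnostics are plain reductions over the leading axis, as in the print above. A small helper along those lines (illustrative only, not part of the patch):

import jax.numpy as jnp

def summarize(info):
    # info: an MCLMCInfo pytree with one leading entry per step
    return {
        "posterior_mean": jnp.mean(info.transformed_x, axis=0),
        "energy_error_rms": jnp.sqrt(jnp.mean(jnp.square(info.dE))),
    }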
SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) kernel = cls.build_kernel(grad_logp, dim, integrator, transform, params) - def init_fn(position: ArrayLike, rng_key: PRNGKey): - return cls.init(position, logdensity_fn, rng_key) + def init_fn(position: ArrayLike): + return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) return SamplingAlgorithm(init_fn, kernel) @@ -162,4 +162,3 @@ def update_momentum(step_size, u, g): uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2) return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r - diff --git a/explore.py b/explore.py deleted file mode 100644 index 88aab2faf..000000000 --- a/explore.py +++ /dev/null @@ -1,235 +0,0 @@ -import jax -import jax.numpy as jnp - -import blackjax -from blackjax.base import SamplingAlgorithm -from blackjax.diagnostics import effective_sample_size -from blackjax.mcmc.mclmc import Parameters -from blackjax.types import PRNGKey -from scipy.fftpack import next_fast_len #type: ignore - - -def logdensity_fn(x): - return -0.5 * jnp.sum(jnp.square(x - 5)) - - -# Initialize the state -initial_position = jnp.array([1.0, 1.0]) - - -dim = 2 - - - - - - - -def run_sampling_algorithm( - sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key -): - keys = jax.random.split(rng_key, num_steps + 1) - state = sampling_algorithm.init(initial_val, keys[0]) - _, info = jax.lax.scan( - lambda s, k: (sampling_algorithm.step(k, s)), state, keys[1:] - ) - return info - - -key = jax.random.PRNGKey(0) -main_key, tune_key = jax.random.split(key) - -def ess_corr(x): - """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html - shape(x) = (num_samples, d)""" - - input_array = jnp.array([x, ]) - - num_chains = 1#input_array.shape[0] - num_samples = input_array.shape[1] - - mean_across_chain = input_array.mean(axis=1, keepdims=True) - # Compute autocovariance estimates for every lag for the input array using FFT. 
- centered_array = input_array - mean_across_chain - m = next_fast_len(2 * num_samples) - ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) - ifft_ary *= jnp.conjugate(ifft_ary) - autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) - autocov_value = ( - jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples - ) - mean_autocov_var = autocov_value.mean(0, keepdims=True) - mean_var0 = (jnp.take(mean_autocov_var, jnp.array([0]), axis=1) * num_samples / (num_samples - 1.0)) - weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - weighted_var = jax.lax.cond( - num_chains > 1, - lambda _: weighted_var+ mean_across_chain.var(axis=0, ddof=1, keepdims=True), - lambda _: weighted_var, - operand=None, - ) - - # Geyer's initial positive sequence - num_samples_even = num_samples - num_samples % 2 - mean_autocov_var_tp1 = jnp.take(mean_autocov_var, jnp.arange(1, num_samples_even), axis=1) - rho_hat = jnp.concatenate([jnp.ones_like(mean_var0), 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var,], axis=1,) - - rho_hat = jnp.moveaxis(rho_hat, 1, 0) - rho_hat_even = rho_hat[0::2] - rho_hat_odd = rho_hat[1::2] - - mask0 = (rho_hat_even + rho_hat_odd) > 0.0 - carry_cond = jnp.ones_like(mask0[0]) - max_t = jnp.zeros_like(mask0[0], dtype=int) - - def positive_sequence_body_fn(state, mask_t): - t, carry_cond, max_t = state - next_mask = carry_cond & mask_t - next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) - return (t + 1, next_mask, next_max_t), next_mask - - (*_, max_t_next), mask = jax.lax.scan( - positive_sequence_body_fn, (0, carry_cond, max_t), mask0 - ) - indices = jnp.indices(max_t_next.shape) - indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) - rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) - # improve estimation - mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) - rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) - - # Geyer's initial monotone sequence - def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): - update_mask = rho_hat_sum_t > rho_hat_sum_tm1 - next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) - return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) - - rho_hat_sum = rho_hat_even + rho_hat_odd - _, (update_mask, update_value) = jax.lax.scan( - monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum - ) - - rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) - rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) - - # compute effective sample size - ess_raw = num_chains * num_samples - tau_hat = (-1.0 - + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) - - rho_hat_even_final[indices] - ) - - tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) - ess = ess_raw / tau_hat - - ### my part (combine all dimensions): ### - neff = ess.squeeze() / num_samples - return 1.0 / jnp.average(1 / neff) - - -# ? 
-# tuning() -num_steps = 10000 -initial_params = Parameters(1.3852125, 1.0604926, jnp.array([1., 1.])) -# Parameters(L=jnp.sqrt(dim),step_size=0.4*jnp.sqrt(dim), inverse_mass_matrix=jnp.array([1.0, 1.0])) -def tune(num_steps : int, params : Parameters, rng_key : PRNGKey) -> Parameters: - - # steps1 = (int)(num_steps * 0.1) - # steps2 = (int)(num_steps * 0.1) - # def tune12(self, x, u, l, g, random_key, L_given, eps, sigma_given, num_steps1, num_steps2): - # """cheap hyperparameter tuning""" - - - # def step(state, outer_weight): - # """one adaptive step of the dynamics""" - # x, u, l, g, E, Feps, Weps, eps_max, key, eps = self.dynamics_adaptive(state[0], L, sigma) - # W, F1, F2 = state[1] - # w = outer_weight * eps - # zero_prevention = 1-outer_weight - # F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter - # F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter - # W += w - - # return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps - - # L = L_given - - # # we use the last num_steps2 to compute the diagonal preconditioner - # outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - - # #initial state - # state = ((x, u, l, g, 0., jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), (0., jnp.zeros(len(x)), jnp.zeros(len(x)))) - # # run the steps - # state, eps = jax.lax.scan(step, init=state, xs= outer_weights, length= num_steps1 + num_steps2) - # # determine L - # if num_steps2 != 0.: - # F1, F2 = state[1][1], state[1][2] - # variances = F2 - jnp.square(F1) - # sigma2 = jnp.average(variances) - - # # optionally we do the diagonal preconditioning (and readjust the stepsize) - # if self.diagonal_preconditioning: - - # # diagonal preconditioning - # sigma = jnp.sqrt(variances) - # L = jnp.sqrt(self.Target.d) - - # #readjust the stepsize - # steps = num_steps2 // 3 #we do some small number of steps - # state, eps = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps) - # else: - # L = jnp.sqrt(sigma2 * self.Target.d) - - # xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state - # return L, eps[-1], sigma, xx, uu, ll, gg, key #return the tuned hyperparameters and the final state - - - # params = - - mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - dim=dim, - transform=lambda x: x, - params=params - ) - out = run_sampling_algorithm( - sampling_algorithm=mclmc, - num_steps= int(num_steps * 0.1), - initial_val=jnp.array([0.1, 0.1]), - rng_key=rng_key, - ) - Lfactor = 0.4 - # ESS = effective_sample_size(out.transformed_x) - ESS = ess_corr(out.transformed_x) - # neff = ESS / num_steps - # ESS = 1.0 / jnp.average(1 / neff) - # print(f"Ess is {ESS}") - # print(f"Ess is {ESS2}") - # print(out.transformed_x) - Lnew = Lfactor * initial_params.step_size / ESS - return Parameters(L=Lnew, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix) - -print(tune(num_steps=10000, params=initial_params, rng_key=tune_key)) - - -mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - dim=dim, - transform=lambda x: x, - params=Parameters( - L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0]) - ), -) - -out = run_sampling_algorithm( - sampling_algorithm=mclmc, - num_steps=10000, - initial_val=jnp.array([0.1, 0.1]), - rng_key=main_key, -) - -print(jnp.mean(out.transformed_x, axis=0)) - -assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [5.0377626, 4.9752364]) - - - From 
a67ecb7612a86e5b7d8aead226b0d559b1ac7a7f Mon Sep 17 00:00:00 2001 From: = Date: Sat, 11 Nov 2023 23:22:31 +0100 Subject: [PATCH 09/78] remove dim from class --- blackjax/mcmc/mclmc.py | 70 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 8ed80d6d2..f0da90ebe 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -52,11 +52,12 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(grad_logp, dim: int, integrator, transform, params: Parameters): - step = integrator(T=update_position(grad_logp), V=update_momentum, dim=dim) +def build_kernel(grad_logp, integrator, transform, params: Parameters): + step = integrator(T=update_position(grad_logp), V=update_momentum) def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, params) + dim = xx.shape[0] # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * params.step_size / params.L) - 1.0) / dim) uu = partially_refresh_momentum(u=uu, rng_key=rng_key, nu=nu) @@ -70,7 +71,7 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: return kernel -def minimal_norm(dim, T, V): +def minimal_norm(T, V): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator def step(state: MCLMCState, params: Parameters): @@ -86,6 +87,7 @@ def step(state: MCLMCState, params: Parameters): uu, r3 = V(dt * lambda_c, uu, gg * sigma) # kinetic energy change + dim = xx.shape[0] kinetic_change = (r1 + r2 + r3) * (dim - 1) return xx, uu, ll, gg, kinetic_change @@ -94,7 +96,64 @@ def step(state: MCLMCState, params: Parameters): class mclmc: - """todo: add documentation""" + """The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias `blackjax.hmc.build_kernel`) can be + cumbersome to manipulate. Since most users only need to specify the kernel + parameters at initialization time, we provide a helper function that + specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.hmc` to SMC, adaptation, etc. algorithms. + + Examples + -------- + + A new HMC kernel can be initialized and used with the following code: + + .. code:: + + hmc = blackjax.hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps) + state = hmc.init(position) + new_state, info = hmc.step(rng_key, state) + + Kernels are not jit-compiled by default so you will need to do it manually: + + .. code:: + + step = jax.jit(hmc.step) + new_state, info = step(rng_key, state) + + Should you need to you can always use the base kernel directly: + + .. code:: + + import blackjax.mcmc.integrators as integrators + + kernel = blackjax.hmc.build_kernel(integrators.mclachlan) + state = blackjax.hmc.init(position, logdensity_fn) + state, info = kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps) + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + TODO + transform + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. + num_integration_steps + The number of steps we take with the symplectic integrator at each + sample step before returning a sample. 
+ divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory.\ + + Returns + ------- + A ``SamplingAlgorithm``. + """ init = staticmethod(init) build_kernel = staticmethod(build_kernel) @@ -102,14 +161,13 @@ class mclmc: def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, - dim: int, transform: Callable, params: Parameters, integrator=minimal_norm, ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, dim, integrator, transform, params) + kernel = cls.build_kernel(grad_logp, integrator, transform, params) def init_fn(position: ArrayLike): return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) From 3dd4f74f44fddd8db49249b905be6192b79f65fb Mon Sep 17 00:00:00 2001 From: = Date: Sun, 12 Nov 2023 00:10:15 +0100 Subject: [PATCH 10/78] add docstrings --- blackjax/mcmc/mclmc.py | 77 ++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 33 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index f0da90ebe..ffcf1f159 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -25,15 +25,22 @@ MCLMCState = integrators.IntegratorState - class MCLMCInfo(NamedTuple): - """Additional information on the MCLMC transition.""" + """Additional information on the MCLMC transition. + + transformed_x + The value of the samples after a transformation (e.g. projection onto lower dim subspace) + logdensity + logdensity at given step + dE + energy difference + + """ transformed_x: Array logdensity: Array dE: float - class Parameters(NamedTuple): """Tunable parameters""" @@ -41,7 +48,6 @@ class Parameters(NamedTuple): step_size: float inverse_mass_matrix: Array - def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) return MCLMCState( @@ -51,8 +57,26 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) - def build_kernel(grad_logp, integrator, transform, params: Parameters): + + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Hamiltonian dynamics. + transform + Value of the difference in energy above which we consider that the transition is divergent. + params + Parameters + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + + """ step = integrator(T=update_position(grad_logp), V=update_momentum) def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: @@ -70,7 +94,6 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: return kernel - def minimal_norm(T, V): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator @@ -96,59 +119,47 @@ def step(state: MCLMCState, params: Parameters): class mclmc: - """The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias `blackjax.hmc.build_kernel`) can be + """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. 
Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that specializes the general kernel. We also add the general kernel and state generator as an attribute to this class so - users only need to pass `blackjax.hmc` to SMC, adaptation, etc. algorithms. + users only need to pass `blackjax.mclmc` to SMC, adaptation, etc. algorithms. Examples -------- - A new HMC kernel can be initialized and used with the following code: + A new mclmc kernel can be initialized and used with the following code: .. code:: - hmc = blackjax.hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps) - state = hmc.init(position) - new_state, info = hmc.step(rng_key, state) + mclmc = blackjax.mcmc.mclmc.mclmc( + logdensity_fn=logdensity_fn, + transform=lambda x: x, + params=params + ) + state = mclmc.init(position) + new_state, info = mclmc.step(rng_key, state) Kernels are not jit-compiled by default so you will need to do it manually: .. code:: - step = jax.jit(hmc.step) + step = jax.jit(mclmc.step) new_state, info = step(rng_key, state) - Should you need to you can always use the base kernel directly: - - .. code:: - - import blackjax.mcmc.integrators as integrators - - kernel = blackjax.hmc.build_kernel(integrators.mclachlan) - state = blackjax.hmc.init(position, logdensity_fn) - state, info = kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps) - Parameters ---------- logdensity_fn The log-density function we wish to draw samples from. - TODO - transform + transform The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy. - num_integration_steps - The number of steps we take with the symplectic integrator at each - sample step before returning a sample. - divergence_threshold - The absolute value of the difference in energy between two states above - which we say that the transition is divergent. The default value is - commonly found in other libraries, and yet is arbitrary. + params + Paramters integrator - (algorithm parameter) The symplectic integrator to use to integrate the trajectory.\ + an integrator. We recommend using the default here. 
Returns ------- From 5d8061db4b0f42bc5890af01b4973144563266f6 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 Nov 2023 16:25:48 +0100 Subject: [PATCH 11/78] add mclmc to init --- blackjax/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 9016d2a0e..ae53a235d 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -9,6 +9,7 @@ from .mcmc.ghmc import ghmc from .mcmc.hmc import hmc from .mcmc.mala import mala +from .mcmc.mclmc import mclmc from .mcmc.marginal_latent_gaussian import mgrad_gaussian from .mcmc.nuts import nuts from .mcmc.periodic_orbital import orbital_hmc @@ -36,6 +37,7 @@ "additive_step_random_walk", "rmh", "irmh", + "mclmc", "elliptical_slice", "ghmc", "sgld", # stochastic gradient mcmc From 2bf639e7e5081107b4e7098c2ade1fdbf4c33d66 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 Nov 2023 19:02:46 +0100 Subject: [PATCH 12/78] move minimal_norm to integrators --- blackjax/mcmc/integrators.py | 22 ++++++++++++++++++++++ blackjax/mcmc/mclmc.py | 28 +++------------------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 09946e9a3..15be9487c 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -246,3 +246,25 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return IntegratorState(position, momentum, logdensity, logdensity_grad) return one_step + +def minimal_norm(T, V, inverse_mass_matrix): + lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator + + def step(state: IntegratorState, step_size): + """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" + + # V T V T V + sigma = jax.numpy.sqrt(inverse_mass_matrix) + uu, r1 = V(step_size * lambda_c, state.momentum, state.logdensity_grad * sigma) + xx, ll, gg = T(step_size, state.position, 0.5 * uu * sigma) + uu, r2 = V(step_size * (1 - 2 * lambda_c), uu, gg * sigma) + xx, ll, gg = T(step_size, xx, 0.5 * uu * sigma) + uu, r3 = V(step_size * lambda_c, uu, gg * sigma) + + # kinetic energy change + dim = xx.shape[0] + kinetic_change = (r1 + r2 + r3) * (dim - 1) + + return xx, uu, ll, gg, kinetic_change + + return step \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ffcf1f159..cf652f6a2 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -77,10 +77,10 @@ def build_kernel(grad_logp, integrator, transform, params: Parameters): information about the transition. 
""" - step = integrator(T=update_position(grad_logp), V=update_momentum) + step = integrator(T=update_position(grad_logp), V=update_momentum, inverse_mass_matrix=params.inverse_mass_matrix) def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: - xx, uu, ll, gg, kinetic_change = step(state, params) + xx, uu, ll, gg, kinetic_change = step(state, params.step_size) dim = xx.shape[0] # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * params.step_size / params.L) - 1.0) / dim) @@ -94,28 +94,6 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: return kernel -def minimal_norm(T, V): - lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator - - def step(state: MCLMCState, params: Parameters): - """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" - - # V T V T V - dt = params.step_size - sigma = jnp.sqrt(params.inverse_mass_matrix) - uu, r1 = V(dt * lambda_c, state.momentum, state.logdensity_grad * sigma) - xx, ll, gg = T(dt, state.position, 0.5 * uu * sigma) - uu, r2 = V(dt * (1 - 2 * lambda_c), uu, gg * sigma) - xx, ll, gg = T(dt, xx, 0.5 * uu * sigma) - uu, r3 = V(dt * lambda_c, uu, gg * sigma) - - # kinetic energy change - dim = xx.shape[0] - kinetic_change = (r1 + r2 + r3) * (dim - 1) - - return xx, uu, ll, gg, kinetic_change - - return step class mclmc: @@ -174,7 +152,7 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, transform: Callable, params: Parameters, - integrator=minimal_norm, + integrator=integrators.minimal_norm, ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) From 172fee0993e95b2879c05449eb1ff055f7bd9099 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 Nov 2023 19:08:29 +0100 Subject: [PATCH 13/78] move update pos and momentum --- blackjax/mcmc/integrators.py | 30 +++++++++++++++++++++++++++++- blackjax/mcmc/mclmc.py | 29 +---------------------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 15be9487c..00e2ad2f0 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -267,4 +267,32 @@ def step(state: IntegratorState, step_size): return xx, uu, ll, gg, kinetic_change - return step \ No newline at end of file + return step + + + + +def update_position_mclmc(grad_logp): + """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) + """ + def update(step_size, x, u): + xx = x + step_size * u + ll, gg = grad_logp(xx) + return xx, ll, gg + + return update + +def update_momentum_mclmc(step_size, u, g): + """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) + similar to the implementation: https://github.com/gregversteeg/esh_dynamics + There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
+ """ + g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) + e = g / g_norm + ue = jax.numpy.dot(u, e) + dim = u.shape[0] + delta = step_size * g_norm / (dim - 1) + zeta = jax.numpy.exp(-delta) + uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u + delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) + return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index cf652f6a2..d4970bad5 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -77,7 +77,7 @@ def build_kernel(grad_logp, integrator, transform, params: Parameters): information about the transition. """ - step = integrator(T=update_position(grad_logp), V=update_momentum, inverse_mass_matrix=params.inverse_mass_matrix) + step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc, inverse_mass_matrix=params.inverse_mass_matrix) def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, params.step_size) @@ -175,13 +175,6 @@ def random_unit_vector(rng_key, dim): return u -def update_position(grad_logp): - def update(step_size, x, u): - xx = x + step_size * u - ll, gg = grad_logp(xx) - return xx, ll, gg - - return update def partially_refresh_momentum(u, rng_key, nu): @@ -189,23 +182,3 @@ def partially_refresh_momentum(u, rng_key, nu): z = nu * jax.random.normal(rng_key, shape=(u.shape[0],)) return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - -### -# integrator -### - - -def update_momentum(step_size, u, g): - """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) - similar to the implementation: https://github.com/gregversteeg/esh_dynamics - There are no exponentials e^delta, which prevents overflows when the gradient norm is large. - """ - g_norm = jnp.sqrt(jnp.sum(jnp.square(g))) - e = g / g_norm - ue = jnp.dot(u, e) - dim = u.shape[0] - delta = step_size * g_norm / (dim - 1) - zeta = jnp.exp(-delta) - uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u - delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2) - return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r From b710e62d1e827432f0e22a7799881bfd6dc7e3d2 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 Nov 2023 19:18:04 +0100 Subject: [PATCH 14/78] remove params --- blackjax/mcmc/mclmc.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index d4970bad5..310093a64 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -41,12 +41,6 @@ class MCLMCInfo(NamedTuple): logdensity: Array dE: float -class Parameters(NamedTuple): - """Tunable parameters""" - - L: float - step_size: float - inverse_mass_matrix: Array def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) @@ -57,7 +51,7 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) -def build_kernel(grad_logp, integrator, transform, params: Parameters): +def build_kernel(grad_logp, integrator, transform, L, step_size, inverse_mass_matrix): """Build a HMC kernel. @@ -67,8 +61,11 @@ def build_kernel(grad_logp, integrator, transform, params: Parameters): The symplectic integrator to use to integrate the Hamiltonian dynamics. transform Value of the difference in energy above which we consider that the transition is divergent. 
- params - Parameters + L + the momentum decoherence rate + step_size + step size of the integrator + inverse mass matrix Returns ------- @@ -77,13 +74,13 @@ def build_kernel(grad_logp, integrator, transform, params: Parameters): information about the transition. """ - step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc, inverse_mass_matrix=params.inverse_mass_matrix) + step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc, inverse_mass_matrix=inverse_mass_matrix) def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: - xx, uu, ll, gg, kinetic_change = step(state, params.step_size) + xx, uu, ll, gg, kinetic_change = step(state, step_size) dim = xx.shape[0] # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * params.step_size / params.L) - 1.0) / dim) + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) uu = partially_refresh_momentum(u=uu, rng_key=rng_key, nu=nu) return MCLMCState(xx, uu, ll, gg), MCLMCInfo( @@ -115,7 +112,9 @@ class mclmc: mclmc = blackjax.mcmc.mclmc.mclmc( logdensity_fn=logdensity_fn, transform=lambda x: x, - params=params + L=L, + step_size=step_size + inverse mass matrix=inverse_mass_matrix ) state = mclmc.init(position) new_state, info = mclmc.step(rng_key, state) @@ -134,7 +133,11 @@ class mclmc: transform The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy. - params + L + the momentum decoherence rate + step_size + step size of the integrator + inverse mass matrix Paramters integrator an integrator. We recommend using the default here. @@ -151,12 +154,14 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, transform: Callable, - params: Parameters, + L, + step_size, + inverse_mass_matrix, integrator=integrators.minimal_norm, ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, integrator, transform, params) + kernel = cls.build_kernel(grad_logp, integrator, transform, L, step_size, inverse_mass_matrix) def init_fn(position: ArrayLike): return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) From 3cc52fd6b7d59aac74cc46a0bf39df33ba6687a1 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 14 Nov 2023 12:22:55 +0100 Subject: [PATCH 15/78] Infer the shape from inverse_mass_matrix outside the function step --- blackjax/mcmc/integrators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 00e2ad2f0..48e8c66b4 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -249,6 +249,7 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: def minimal_norm(T, V, inverse_mass_matrix): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator + dim = inverse_mass_matrix.shape[0] def step(state: IntegratorState, step_size): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" @@ -262,7 +263,6 @@ def step(state: IntegratorState, step_size): uu, r3 = V(step_size * lambda_c, uu, gg * sigma) # kinetic energy change - dim = xx.shape[0] kinetic_change = (r1 + r2 + r3) * (dim - 1) return xx, uu, ll, gg, kinetic_change From 57d5c3b097fb7677407f4c90a7bb3828fd209935 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 14 Nov 2023 18:05:23 +0100 Subject: [PATCH 16/78] use tree_map --- blackjax/mcmc/integrators.py | 11 ++++++----- 1 file changed, 6 
insertions(+), 5 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 48e8c66b4..dd0154787 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -256,11 +256,11 @@ def step(state: IntegratorState, step_size): # V T V T V sigma = jax.numpy.sqrt(inverse_mass_matrix) - uu, r1 = V(step_size * lambda_c, state.momentum, state.logdensity_grad * sigma) - xx, ll, gg = T(step_size, state.position, 0.5 * uu * sigma) - uu, r2 = V(step_size * (1 - 2 * lambda_c), uu, gg * sigma) - xx, ll, gg = T(step_size, xx, 0.5 * uu * sigma) - uu, r3 = V(step_size * lambda_c, uu, gg * sigma) + uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g * sigma), state.momentum, state.logdensity_grad) + xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u * sigma), state.position, uu) + uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g * sigma), uu, gg) + xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u * sigma), xx, uu) + uu, r3 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g * sigma), uu, gg) # kinetic energy change kinetic_change = (r1 + r2 + r3) * (dim - 1) @@ -296,3 +296,4 @@ def update_momentum_mclmc(step_size, u, g): uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r + From 7e70d783eadf22bb4e36335ffad96631820120ec Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Nov 2023 18:14:59 +0100 Subject: [PATCH 17/78] integration now aligned with mclmc repo --- blackjax/mcmc/mclmc.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 310093a64..1e72efb46 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -38,12 +38,14 @@ class MCLMCInfo(NamedTuple): """ transformed_x: Array - logdensity: Array + logdensity: float dE: float def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) + jax.debug.print("🤯 {x} initial momentum 🤯", x=random_unit_vector(rng_key, dim=x_initial.shape[0])) + return MCLMCState( position=x_initial, momentum=random_unit_vector(rng_key, dim=x_initial.shape[0]), @@ -78,6 +80,9 @@ def build_kernel(grad_logp, integrator, transform, L, step_size, inverse_mass_ma def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, step_size) + jax.debug.print("🤯 {x} new 🤯", x=ll) + + dim = xx.shape[0] # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) From 1343463e992c6b5a1a78e0821a4dd77cad8dda3b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Nov 2023 18:21:52 +0100 Subject: [PATCH 18/78] dE and logdensity align too (fixed sign error) --- blackjax/mcmc/mclmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 1e72efb46..3a8a4c0b8 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -91,7 +91,7 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: return MCLMCState(xx, uu, ll, gg), MCLMCInfo( transformed_x=transform(xx), logdensity=ll, - dE=kinetic_change + ll - state.logdensity, + dE=kinetic_change - ll + state.logdensity, ) return kernel From e53a877cd1379181f0ee76a10dfbfe5574d54f2d Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 
Nov 2023 19:59:38 +0100 Subject: [PATCH 19/78] make L and step size arguments to kernel --- blackjax/mcmc/mclmc.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 3a8a4c0b8..8c6e376f8 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -39,12 +39,13 @@ class MCLMCInfo(NamedTuple): transformed_x: Array logdensity: float + kinetic_change: float dE: float def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) - jax.debug.print("🤯 {x} initial momentum 🤯", x=random_unit_vector(rng_key, dim=x_initial.shape[0])) + # jax.debug.print("🤯 {x} initial momentum 🤯", x=random_unit_vector(rng_key, dim=x_initial.shape[0])) return MCLMCState( position=x_initial, @@ -53,7 +54,7 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) -def build_kernel(grad_logp, integrator, transform, L, step_size, inverse_mass_matrix): +def build_kernel(grad_logp, integrator, transform, inverse_mass_matrix): """Build a HMC kernel. @@ -78,9 +79,9 @@ def build_kernel(grad_logp, integrator, transform, L, step_size, inverse_mass_ma """ step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc, inverse_mass_matrix=inverse_mass_matrix) - def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: + def kernel(rng_key: PRNGKey, state: MCLMCState, L : float, step_size : float) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, step_size) - jax.debug.print("🤯 {x} new 🤯", x=ll) + # jax.debug.print("🤯 {x} new 🤯", x=(kinetic_change, ll, state.logdensity)) dim = xx.shape[0] @@ -92,6 +93,7 @@ def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: transformed_x=transform(xx), logdensity=ll, dE=kinetic_change - ll + state.logdensity, + kinetic_change=kinetic_change ) return kernel @@ -166,12 +168,15 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, integrator, transform, L, step_size, inverse_mass_matrix) + kernel = cls.build_kernel(grad_logp, integrator, transform, inverse_mass_matrix) + + def update_fn(rng_key, state): + return kernel(rng_key, state, L, step_size) def init_fn(position: ArrayLike): return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) - return SamplingAlgorithm(init_fn, kernel) + return SamplingAlgorithm(init_fn, update_fn) ### From 05517b679fd75d79023f7a4680249e0f66195efd Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Nov 2023 22:21:13 +0100 Subject: [PATCH 20/78] rough draft of tuning: works --- blackjax/mcmc/integrators.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index dd0154787..e65db3d67 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -256,7 +256,11 @@ def step(state: IntegratorState, step_size): # V T V T V sigma = jax.numpy.sqrt(inverse_mass_matrix) - uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g * sigma), state.momentum, state.logdensity_grad) + # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) + uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g * sigma), state.momentum, + state.logdensity_grad) + # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) + xx, ll, gg = jax.tree_util.tree_map(lambda x, 
u : T(step_size, x, 0.5 * u * sigma), state.position, uu) uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g * sigma), uu, gg) xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u * sigma), xx, uu) @@ -290,6 +294,7 @@ def update_momentum_mclmc(step_size, u, g): g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) e = g / g_norm ue = jax.numpy.dot(u, e) + # jax.debug.print("🤯 {x} inside momentum update 🤯", x=(ue)) dim = u.shape[0] delta = step_size * g_norm / (dim - 1) zeta = jax.numpy.exp(-delta) From d84a23d5d942928750238b91a2f294264f97bd2b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Nov 2023 22:26:13 +0100 Subject: [PATCH 21/78] remove inv mass matrix --- blackjax/mcmc/integrators.py | 16 +++++++--------- blackjax/mcmc/mclmc.py | 14 ++++---------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index e65db3d67..482e037e6 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -247,27 +247,25 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return one_step -def minimal_norm(T, V, inverse_mass_matrix): +def minimal_norm(T, V): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator - dim = inverse_mass_matrix.shape[0] def step(state: IntegratorState, step_size): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" # V T V T V - sigma = jax.numpy.sqrt(inverse_mass_matrix) # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) - uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g * sigma), state.momentum, + uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g), state.momentum, state.logdensity_grad) # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) - xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u * sigma), state.position, uu) - uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g * sigma), uu, gg) - xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u * sigma), xx, uu) - uu, r3 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g * sigma), uu, gg) + xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u), state.position, uu) + uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g), uu, gg) + xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u), xx, uu) + uu, r3 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g), uu, gg) # kinetic energy change - kinetic_change = (r1 + r2 + r3) * (dim - 1) + kinetic_change = (r1 + r2 + r3) * (uu.shape[0] - 1) return xx, uu, ll, gg, kinetic_change diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 8c6e376f8..30cd0f1af 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -54,7 +54,7 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) -def build_kernel(grad_logp, integrator, transform, inverse_mass_matrix): +def build_kernel(grad_logp, integrator, transform): """Build a HMC kernel. 
@@ -68,7 +68,6 @@ def build_kernel(grad_logp, integrator, transform, inverse_mass_matrix): the momentum decoherence rate step_size step size of the integrator - inverse_mass_matrix Returns ------- @@ -77,7 +76,7 @@ def build_kernel(grad_logp, integrator, transform, inverse_mass_matrix): information about the transition. """ - step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc, inverse_mass_matrix=inverse_mass_matrix) + step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc) def kernel(rng_key: PRNGKey, state: MCLMCState, L : float, step_size : float) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, step_size) @@ -121,7 +120,6 @@ class mclmc: transform=lambda x: x, L=L, step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix ) state = mclmc.init(position) new_state, info = mclmc.step(rng_key, state) @@ -138,14 +136,11 @@ class mclmc: logdensity_fn The log-density function we wish to draw samples from. transform - The value to use for the inverse mass matrix when drawing a value for - the momentum and computing the kinetic energy. + A function to perform on the samples drawn from the target distribution L the momentum decoherence rate step_size step size of the integrator - inverse_mass_matrix - Paramters integrator an integrator. We recommend using the default here. @@ -163,12 +158,11 @@ def __new__( # type: ignore[misc] transform: Callable, L, step_size, - inverse_mass_matrix, integrator=integrators.minimal_norm, ) -> SamplingAlgorithm: grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, integrator, transform, inverse_mass_matrix) + kernel = cls.build_kernel(grad_logp, integrator, transform) def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) From de1e5cf262b80c626e346225883cbb33a935c49b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Nov 2023 23:02:21 +0100 Subject: [PATCH 22/78] almost correct --- blackjax/mcmc/integrators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 482e037e6..2125aaf69 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -254,10 +254,10 @@ def step(state: IntegratorState, step_size): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" # V T V T V - # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) + jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g), state.momentum, state.logdensity_grad) - # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) + jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u), state.position, uu) uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g), uu, gg) From 263ab3acfdb173c3dddeeac188a502316d874141 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Nov 2023 12:26:22 +0100 Subject: [PATCH 23/78] almost correct --- explore.py | 328 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 explore.py diff --git a/explore.py b/explore.py new file mode 100644 index 000000000..eed473902 --- /dev/null +++ b/explore.py @@ -0,0 +1,328 @@ +from typing import NamedTuple +from chex import Array +import jax +import 
jax.numpy as jnp +from scipy.fftpack import next_fast_len # type: ignore + +import blackjax +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import minimal_norm +from blackjax.mcmc.mclmc import MCLMCState, build_kernel +# from blackjax.diagnostics import effective_sample_size +from blackjax.types import PRNGKey + +class Parameters(NamedTuple): + """Tunable parameters""" + + L: float + step_size: float + + +def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + +def run_sampling_algorithm( + sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key +): + # keys = jax.random.split(rng_key, num_steps) + keys = jnp.array([jax.random.PRNGKey(0)]*num_steps) + state = sampling_algorithm.init(initial_val) + print("\n\n", state.position, "\n\n") + print("\n\n", state.momentum, "\n\n") + _, info = jax.lax.scan( + lambda s, k: (sampling_algorithm.step(k, s)), state, keys + ) + return info + + +key = jax.random.PRNGKey(0) +main_key, tune_key = jax.random.split(key) + + +def ess_corr(x): + """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html + shape(x) = (num_samples, d)""" + + input_array = jnp.array( + [ + x, + ] + ) + + num_chains = 1 # input_array.shape[0] + num_samples = input_array.shape[1] + + mean_across_chain = input_array.mean(axis=1, keepdims=True) + # Compute autocovariance estimates for every lag for the input array using FFT. + centered_array = input_array - mean_across_chain + m = next_fast_len(2 * num_samples) + ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) + ifft_ary *= jnp.conjugate(ifft_ary) + autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) + autocov_value = ( + jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples + ) + mean_autocov_var = autocov_value.mean(0, keepdims=True) + mean_var0 = ( + jnp.take(mean_autocov_var, jnp.array([0]), axis=1) + * num_samples + / (num_samples - 1.0) + ) + weighted_var = mean_var0 * (num_samples - 1.0) / num_samples + weighted_var = jax.lax.cond( + num_chains > 1, + lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), + lambda _: weighted_var, + operand=None, + ) + + # Geyer's initial positive sequence + num_samples_even = num_samples - num_samples % 2 + mean_autocov_var_tp1 = jnp.take( + mean_autocov_var, jnp.arange(1, num_samples_even), axis=1 + ) + rho_hat = jnp.concatenate( + [ + jnp.ones_like(mean_var0), + 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, + ], + axis=1, + ) + + rho_hat = jnp.moveaxis(rho_hat, 1, 0) + rho_hat_even = rho_hat[0::2] + rho_hat_odd = rho_hat[1::2] + + mask0 = (rho_hat_even + rho_hat_odd) > 0.0 + carry_cond = jnp.ones_like(mask0[0]) + max_t = jnp.zeros_like(mask0[0], dtype=int) + + def positive_sequence_body_fn(state, mask_t): + t, carry_cond, max_t = state + next_mask = carry_cond & mask_t + next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) + return (t + 1, next_mask, next_max_t), next_mask + + (*_, max_t_next), mask = jax.lax.scan( + positive_sequence_body_fn, (0, carry_cond, max_t), mask0 + ) + indices = jnp.indices(max_t_next.shape) + indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) + rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) + # improve estimation + mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) + rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) + + # Geyer's initial monotone sequence + def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): + 
update_mask = rho_hat_sum_t > rho_hat_sum_tm1 + next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) + return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) + + rho_hat_sum = rho_hat_even + rho_hat_odd + _, (update_mask, update_value) = jax.lax.scan( + monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum + ) + + rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) + rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) + + # compute effective sample size + ess_raw = num_chains * num_samples + tau_hat = ( + -1.0 + + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) + - rho_hat_even_final[indices] + ) + + tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) + ess = ess_raw / tau_hat + + ### my part (combine all dimensions): ### + neff = ess.squeeze() / num_samples + return 1.0 / jnp.average(1 / neff) + + +# ? +# tuning() +num_steps = 100 +initial_params = Parameters(1.9913111, 0.6458658) + +def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): + """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" + + nonans = jnp.all(jnp.isfinite(xx)) + + return nonans, *jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (xx, uu, ll, gg, eps_max, dK), (x, u, l, g, eps * 0.8, 0.)) + + +def dynamics_adaptive(dynamics, state, L, sigma): + """One step of the dynamics with the adaptive stepsize""" + + x, u, l, g, E, Feps, Weps, eps_max, key = state + + eps = jnp.power(Feps/Weps, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. + eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences + + # grad_logp = jax.value_and_grad(logdensity_fn) + + + # print(L_given, eps, "\n\n\n") + # m = build_kernel(grad_logp, minimal_norm, lambda x:x, L_given, eps, sigma) + # dynamics = lambda x,u,l,g, key : m(jax.random.PRNGKey(0), MCLMCState(x,u,l,g)) + + # dynamics + + # xx, uu, ll, gg, kinetic_change, key = dynamics(x, u, g, key, L, eps, sigma) + # jax.debug.print("🤯 {x} x 🤯", x=(x,u,l,g, E, Feps, Weps, eps_max)) + jax.debug.print("🤯 {x} L eps 🤯", x=(L, eps, sigma)) + jax.debug.print("🤯 {x} x u 🤯", x=(x,u, g)) + state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, -l, -g), L=L, step_size=eps) + + xx, uu, ll, gg = state + ll, gg = -ll, -gg + # jax.debug.print("🤯 {x} xx uu 🤯", x=(xx,uu)) + kinetic_change = info.kinetic_change + + varEwanted = 5e-4 + sigma_xi= 1.5 + neff = 150 # effective number of steps used to determine the stepsize in the adaptive step + gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step + + + # step updating + # jax.debug.print("🤯 {x} L eps 🤯", x=(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change)) + success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change) + + + DE = info.dE # energy difference + # jax.debug.print("🤯 {x} DE 🤯", x=(DE, kinetic_change)) + EE = E + DE # energy + # Warning: var = 0 if there were nans, but we will give it a very small weight + xi = ((DE ** 2) / (xx.shape[0] * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi + w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. 
+ Feps = gamma * Feps + w * (xi/jnp.power(eps, 6.0)) # Kalman update the linear combinations + Weps = gamma * Weps + w + + return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success + +def tune12(x, u, l, g, random_key, L_given, eps, sigma_given, num_steps1, num_steps2): + """cheap hyperparameter tuning""" + + # mclmc = blackjax.mclmc( + # logdensity_fn=logdensity_fn, transform=lambda x: x, L=params.L, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix + # ) + + + sigma = sigma_given + gr = jax.value_and_grad(logdensity_fn) + dynamics = build_kernel(grad_logp=gr, + integrator=minimal_norm, transform=lambda x:x) + + def step(state, outer_weight): + """one adaptive step of the dynamics""" + # x,u,l,g = state + # E, Feps, Weps, eps_max = 1.0,1.0,1.0,1.0 + x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(dynamics, state[0], L, sigma) + W, F1, F2 = state[1] + w = outer_weight * eps + zero_prevention = 1-outer_weight + F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter + F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter + W += w + + return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps + + L = L_given + + # we use the last num_steps2 to compute the diagonal preconditioner + outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + #initial state + state = ((x, u, l, g, 0., jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), (0., jnp.zeros(len(x)), jnp.zeros(len(x)))) + # run the steps + state, eps = jax.lax.scan(step, init=state, xs= outer_weights, length= num_steps1 + num_steps2) + # determine L + if num_steps2 != 0.: + F1, F2 = state[1][1], state[1][2] + variances = F2 - jnp.square(F1) + sigma2 = jnp.average(variances) + + L = jnp.sqrt(sigma2 * x.shape[0]) + + xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state + return L, eps[-1], sigma, xx, uu, ll, gg, key #return the tuned hyperparameters and the final state + +def tune3(x, u, l, g, rng_key, L, eps, sigma, num_steps): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + print(L, eps, sigma, x,u, "initial params") + + + + gr = jax.value_and_grad(logdensity_fn) + kernel = build_kernel(grad_logp=gr, + integrator=minimal_norm, transform=lambda x:x) + + keys = jnp.array([jax.random.PRNGKey(0)]*num_steps) + + state, info = jax.lax.scan( + lambda s, k: (kernel(k, s, L, eps)), MCLMCState(x,u,-l,-g), keys + ) + + # state, info = kernel(jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps) + # xx,uu,ll,gg = state + X = info.transformed_x + + # sample_full(num_steps, x, u, l, g, random_key, L, eps, sigma) + ESS = ess_corr(X) + Lfactor = 0.4 + Lnew = Lfactor * eps / ESS # = 0.4 * correlation length + print(ESS, "ess", X, Lfactor, eps) + return Lnew, state + +def tune(num_steps: int, params: Parameters, rng_key: PRNGKey) -> Parameters: + + + + x, u, l, g, key, L, eps, sigma, steps1, steps2 = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), jax.random.PRNGKey(0), 1.4142135, 0.56568545, jnp.array([1., 1.]), 10, 10 + + L, eps, sigma, x, u, l, g, key = tune12(x, u, l, g, key, L, eps, sigma, steps1, steps2) + print("L, eps post tune12", L, eps) + + + + + + steps3 = int(num_steps * 0.1) + L, state = tune3(x, u, l, g, key, L, eps, sigma, steps3) + print("L post tune3", L) + return L, eps, state + + +L, eps, state = (tune(num_steps=100, 
params=initial_params, rng_key=tune_key)) +print("L, eps post tuning", L, eps) +raise Exception +mclmc = blackjax.mcmc.mclmc.mclmc( + logdensity_fn=logdensity_fn, + transform=lambda x: x, + # L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0] + step_size=0.56568545, L=1.4142135, +) + + +out = run_sampling_algorithm( + sampling_algorithm=mclmc, + num_steps=100, + initial_val=jnp.array([0.1, 0.1]), + rng_key=main_key, +) + +print(jnp.mean(out.transformed_x, axis=0)) + +# print(logdensity_fn(jnp.array([0.1, 0.1]))) +# print(out) + +assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [-1.2130139, 1.5367734]) + + From 777213d6463529fa1f56b4e541ad893c010c50cb Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Nov 2023 12:30:44 +0100 Subject: [PATCH 24/78] move tuning to adaptation --- blackjax/adaptation/step_size.py | 274 +++++++++++++++++++++++++++++++ explore.py | 267 +----------------------------- 2 files changed, 278 insertions(+), 263 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2d6b0182f..2edfa6184 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -16,8 +16,10 @@ import jax import jax.numpy as jnp +from scipy.fft import next_fast_len from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.mclmc import MCLMCState from blackjax.optimizers.dual_averaging import dual_averaging from blackjax.types import PRNGKey @@ -257,3 +259,275 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size + + + + + + + + + + +#### mclmc + +class Parameters(NamedTuple): + """Tunable parameters""" + + L: float + step_size: float + + +def ess_corr(x): + """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html + shape(x) = (num_samples, d)""" + + input_array = jnp.array( + [ + x, + ] + ) + + num_chains = 1 # input_array.shape[0] + num_samples = input_array.shape[1] + + mean_across_chain = input_array.mean(axis=1, keepdims=True) + # Compute autocovariance estimates for every lag for the input array using FFT. 
+ centered_array = input_array - mean_across_chain + m = next_fast_len(2 * num_samples) + ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) + ifft_ary *= jnp.conjugate(ifft_ary) + autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) + autocov_value = ( + jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples + ) + mean_autocov_var = autocov_value.mean(0, keepdims=True) + mean_var0 = ( + jnp.take(mean_autocov_var, jnp.array([0]), axis=1) + * num_samples + / (num_samples - 1.0) + ) + weighted_var = mean_var0 * (num_samples - 1.0) / num_samples + weighted_var = jax.lax.cond( + num_chains > 1, + lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), + lambda _: weighted_var, + operand=None, + ) + + # Geyer's initial positive sequence + num_samples_even = num_samples - num_samples % 2 + mean_autocov_var_tp1 = jnp.take( + mean_autocov_var, jnp.arange(1, num_samples_even), axis=1 + ) + rho_hat = jnp.concatenate( + [ + jnp.ones_like(mean_var0), + 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, + ], + axis=1, + ) + + rho_hat = jnp.moveaxis(rho_hat, 1, 0) + rho_hat_even = rho_hat[0::2] + rho_hat_odd = rho_hat[1::2] + + mask0 = (rho_hat_even + rho_hat_odd) > 0.0 + carry_cond = jnp.ones_like(mask0[0]) + max_t = jnp.zeros_like(mask0[0], dtype=int) + + def positive_sequence_body_fn(state, mask_t): + t, carry_cond, max_t = state + next_mask = carry_cond & mask_t + next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) + return (t + 1, next_mask, next_max_t), next_mask + + (*_, max_t_next), mask = jax.lax.scan( + positive_sequence_body_fn, (0, carry_cond, max_t), mask0 + ) + indices = jnp.indices(max_t_next.shape) + indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) + rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) + # improve estimation + mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) + rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) + + # Geyer's initial monotone sequence + def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): + update_mask = rho_hat_sum_t > rho_hat_sum_tm1 + next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) + return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) + + rho_hat_sum = rho_hat_even + rho_hat_odd + _, (update_mask, update_value) = jax.lax.scan( + monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum + ) + + rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) + rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) + + # compute effective sample size + ess_raw = num_chains * num_samples + tau_hat = ( + -1.0 + + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) + - rho_hat_even_final[indices] + ) + + tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) + ess = ess_raw / tau_hat + + ### my part (combine all dimensions): ### + neff = ess.squeeze() / num_samples + return 1.0 / jnp.average(1 / neff) + + +# ? +# tuning() +num_steps = 100 +initial_params = Parameters(1.9913111, 0.6458658) + +def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): + """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" + + nonans = jnp.all(jnp.isfinite(xx)) + + return nonans, *jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (xx, uu, ll, gg, eps_max, dK), (x, u, l, g, eps * 0.8, 0.)) + + +def dynamics_adaptive(dynamics, state, L, sigma): + """One step of the dynamics with the adaptive stepsize""" + + x, u, l, g, E, Feps, Weps, eps_max, key = state + + eps = jnp.power(Feps/Weps, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. + eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences + + # grad_logp = jax.value_and_grad(logdensity_fn) + + + # print(L_given, eps, "\n\n\n") + # m = build_kernel(grad_logp, minimal_norm, lambda x:x, L_given, eps, sigma) + # dynamics = lambda x,u,l,g, key : m(jax.random.PRNGKey(0), MCLMCState(x,u,l,g)) + + # dynamics + + # xx, uu, ll, gg, kinetic_change, key = dynamics(x, u, g, key, L, eps, sigma) + # jax.debug.print("🤯 {x} x 🤯", x=(x,u,l,g, E, Feps, Weps, eps_max)) + jax.debug.print("🤯 {x} L eps 🤯", x=(L, eps, sigma)) + jax.debug.print("🤯 {x} x u 🤯", x=(x,u, g)) + state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, -l, -g), L=L, step_size=eps) + + xx, uu, ll, gg = state + ll, gg = -ll, -gg + # jax.debug.print("🤯 {x} xx uu 🤯", x=(xx,uu)) + kinetic_change = info.kinetic_change + + varEwanted = 5e-4 + sigma_xi= 1.5 + neff = 150 # effective number of steps used to determine the stepsize in the adaptive step + gamma = (neff - 1.0) / (neff + 1.0) # forgetting factor in the adaptive step + + + # step updating + # jax.debug.print("🤯 {x} L eps 🤯", x=(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change)) + success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change) + + + DE = info.dE # energy difference + # jax.debug.print("🤯 {x} DE 🤯", x=(DE, kinetic_change)) + EE = E + DE # energy + # Warning: var = 0 if there were nans, but we will give it a very small weight + xi = ((DE ** 2) / (xx.shape[0] * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi + w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight which reduces the impact of stepsizes which are much larger or much smaller than the desired one. 
+ Feps = gamma * Feps + w * (xi/jnp.power(eps, 6.0)) # Kalman update the linear combinations + Weps = gamma * Weps + w + + return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success + +def tune12(kernel,x, u, l, g, random_key, L_given, eps, sigma_given, num_steps1, num_steps2): + """cheap hyperparameter tuning""" + + # mclmc = blackjax.mclmc( + # logdensity_fn=logdensity_fn, transform=lambda x: x, L=params.L, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix + # ) + + + sigma = sigma_given + + def step(state, outer_weight): + """one adaptive step of the dynamics""" + # x,u,l,g = state + # E, Feps, Weps, eps_max = 1.0,1.0,1.0,1.0 + x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(kernel, state[0], L, sigma) + W, F1, F2 = state[1] + w = outer_weight * eps + zero_prevention = 1-outer_weight + F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter + F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter + W += w + + return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps + + L = L_given + + # we use the last num_steps2 to compute the diagonal preconditioner + outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + #initial state + state = ((x, u, l, g, 0., jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), (0., jnp.zeros(len(x)), jnp.zeros(len(x)))) + # run the steps + state, eps = jax.lax.scan(step, init=state, xs= outer_weights, length= num_steps1 + num_steps2) + # determine L + if num_steps2 != 0.: + F1, F2 = state[1][1], state[1][2] + variances = F2 - jnp.square(F1) + sigma2 = jnp.average(variances) + + L = jnp.sqrt(sigma2 * x.shape[0]) + + xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state + return L, eps[-1], sigma, xx, uu, ll, gg, key #return the tuned hyperparameters and the final state + +def tune3(kernel, x, u, l, g, rng_key, L, eps, sigma, num_steps): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + print(L, eps, sigma, x,u, "initial params") + + + + + + keys = jnp.array([jax.random.PRNGKey(0)]*num_steps) + + state, info = jax.lax.scan( + lambda s, k: (kernel(k, s, L, eps)), MCLMCState(x,u,-l,-g), keys + ) + + # state, info = kernel(jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps) + # xx,uu,ll,gg = state + X = info.transformed_x + + # sample_full(num_steps, x, u, l, g, random_key, L, eps, sigma) + ESS = ess_corr(X) + Lfactor = 0.4 + Lnew = Lfactor * eps / ESS # = 0.4 * correlation length + print(ESS, "ess", X, Lfactor, eps) + return Lnew, state + +def tune(kernel, num_steps: int, rng_key: PRNGKey) -> Parameters: + + + + x, u, l, g, key, L, eps, sigma, steps1, steps2 = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), jax.random.PRNGKey(0), 1.4142135, 0.56568545, jnp.array([1., 1.]), 10, 10 + + L, eps, sigma, x, u, l, g, key = tune12(kernel, x, u, l, g, key, L, eps, sigma, steps1, steps2) + print("L, eps post tune12", L, eps) + + + + + + steps3 = int(num_steps * 0.1) + L, state = tune3(kernel, x, u, l, g, key, L, eps, sigma, steps3) + print("L post tune3", L) + return L, eps, state \ No newline at end of file diff --git a/explore.py b/explore.py index eed473902..14309c886 100644 --- a/explore.py +++ b/explore.py @@ -5,17 +5,13 @@ from scipy.fftpack import next_fast_len # type: ignore import blackjax +from blackjax.adaptation.step_size import tune from 
blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import minimal_norm from blackjax.mcmc.mclmc import MCLMCState, build_kernel # from blackjax.diagnostics import effective_sample_size from blackjax.types import PRNGKey -class Parameters(NamedTuple): - """Tunable parameters""" - - L: float - step_size: float def logdensity_fn(x): @@ -40,267 +36,12 @@ def run_sampling_algorithm( main_key, tune_key = jax.random.split(key) -def ess_corr(x): - """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html - shape(x) = (num_samples, d)""" - - input_array = jnp.array( - [ - x, - ] - ) - - num_chains = 1 # input_array.shape[0] - num_samples = input_array.shape[1] - - mean_across_chain = input_array.mean(axis=1, keepdims=True) - # Compute autocovariance estimates for every lag for the input array using FFT. - centered_array = input_array - mean_across_chain - m = next_fast_len(2 * num_samples) - ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) - ifft_ary *= jnp.conjugate(ifft_ary) - autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) - autocov_value = ( - jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples - ) - mean_autocov_var = autocov_value.mean(0, keepdims=True) - mean_var0 = ( - jnp.take(mean_autocov_var, jnp.array([0]), axis=1) - * num_samples - / (num_samples - 1.0) - ) - weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - weighted_var = jax.lax.cond( - num_chains > 1, - lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), - lambda _: weighted_var, - operand=None, - ) - - # Geyer's initial positive sequence - num_samples_even = num_samples - num_samples % 2 - mean_autocov_var_tp1 = jnp.take( - mean_autocov_var, jnp.arange(1, num_samples_even), axis=1 - ) - rho_hat = jnp.concatenate( - [ - jnp.ones_like(mean_var0), - 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, - ], - axis=1, - ) - - rho_hat = jnp.moveaxis(rho_hat, 1, 0) - rho_hat_even = rho_hat[0::2] - rho_hat_odd = rho_hat[1::2] - - mask0 = (rho_hat_even + rho_hat_odd) > 0.0 - carry_cond = jnp.ones_like(mask0[0]) - max_t = jnp.zeros_like(mask0[0], dtype=int) - - def positive_sequence_body_fn(state, mask_t): - t, carry_cond, max_t = state - next_mask = carry_cond & mask_t - next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) - return (t + 1, next_mask, next_max_t), next_mask - - (*_, max_t_next), mask = jax.lax.scan( - positive_sequence_body_fn, (0, carry_cond, max_t), mask0 - ) - indices = jnp.indices(max_t_next.shape) - indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) - rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) - # improve estimation - mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) - rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) - - # Geyer's initial monotone sequence - def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): - update_mask = rho_hat_sum_t > rho_hat_sum_tm1 - next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) - return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) - - rho_hat_sum = rho_hat_even + rho_hat_odd - _, (update_mask, update_value) = jax.lax.scan( - monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum - ) - - rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) - rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) - - # compute effective sample size - ess_raw = num_chains * num_samples - tau_hat = ( - 
-1.0 - + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) - - rho_hat_even_final[indices] - ) - - tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) - ess = ess_raw / tau_hat - - ### my part (combine all dimensions): ### - neff = ess.squeeze() / num_samples - return 1.0 / jnp.average(1 / neff) - - -# ? -# tuning() -num_steps = 100 -initial_params = Parameters(1.9913111, 0.6458658) - -def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): - """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" - - nonans = jnp.all(jnp.isfinite(xx)) - - return nonans, *jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (xx, uu, ll, gg, eps_max, dK), (x, u, l, g, eps * 0.8, 0.)) - - -def dynamics_adaptive(dynamics, state, L, sigma): - """One step of the dynamics with the adaptive stepsize""" - - x, u, l, g, E, Feps, Weps, eps_max, key = state - - eps = jnp.power(Feps/Weps, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. - eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - - # grad_logp = jax.value_and_grad(logdensity_fn) - - - # print(L_given, eps, "\n\n\n") - # m = build_kernel(grad_logp, minimal_norm, lambda x:x, L_given, eps, sigma) - # dynamics = lambda x,u,l,g, key : m(jax.random.PRNGKey(0), MCLMCState(x,u,l,g)) - - # dynamics - - # xx, uu, ll, gg, kinetic_change, key = dynamics(x, u, g, key, L, eps, sigma) - # jax.debug.print("🤯 {x} x 🤯", x=(x,u,l,g, E, Feps, Weps, eps_max)) - jax.debug.print("🤯 {x} L eps 🤯", x=(L, eps, sigma)) - jax.debug.print("🤯 {x} x u 🤯", x=(x,u, g)) - state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, -l, -g), L=L, step_size=eps) - - xx, uu, ll, gg = state - ll, gg = -ll, -gg - # jax.debug.print("🤯 {x} xx uu 🤯", x=(xx,uu)) - kinetic_change = info.kinetic_change - - varEwanted = 5e-4 - sigma_xi= 1.5 - neff = 150 # effective number of steps used to determine the stepsize in the adaptive step - gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step - - - # step updating - # jax.debug.print("🤯 {x} L eps 🤯", x=(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change)) - success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change) - - - DE = info.dE # energy difference - # jax.debug.print("🤯 {x} DE 🤯", x=(DE, kinetic_change)) - EE = E + DE # energy - # Warning: var = 0 if there were nans, but we will give it a very small weight - xi = ((DE ** 2) / (xx.shape[0] * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi - w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. 
- Feps = gamma * Feps + w * (xi/jnp.power(eps, 6.0)) # Kalman update the linear combinations - Weps = gamma * Weps + w - - return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success - -def tune12(x, u, l, g, random_key, L_given, eps, sigma_given, num_steps1, num_steps2): - """cheap hyperparameter tuning""" - - # mclmc = blackjax.mclmc( - # logdensity_fn=logdensity_fn, transform=lambda x: x, L=params.L, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix - # ) - - - sigma = sigma_given - gr = jax.value_and_grad(logdensity_fn) - dynamics = build_kernel(grad_logp=gr, - integrator=minimal_norm, transform=lambda x:x) - - def step(state, outer_weight): - """one adaptive step of the dynamics""" - # x,u,l,g = state - # E, Feps, Weps, eps_max = 1.0,1.0,1.0,1.0 - x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(dynamics, state[0], L, sigma) - W, F1, F2 = state[1] - w = outer_weight * eps - zero_prevention = 1-outer_weight - F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter - F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter - W += w - - return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps - - L = L_given - - # we use the last num_steps2 to compute the diagonal preconditioner - outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - - #initial state - state = ((x, u, l, g, 0., jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), (0., jnp.zeros(len(x)), jnp.zeros(len(x)))) - # run the steps - state, eps = jax.lax.scan(step, init=state, xs= outer_weights, length= num_steps1 + num_steps2) - # determine L - if num_steps2 != 0.: - F1, F2 = state[1][1], state[1][2] - variances = F2 - jnp.square(F1) - sigma2 = jnp.average(variances) - - L = jnp.sqrt(sigma2 * x.shape[0]) - - xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state - return L, eps[-1], sigma, xx, uu, ll, gg, key #return the tuned hyperparameters and the final state - -def tune3(x, u, l, g, rng_key, L, eps, sigma, num_steps): - """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - print(L, eps, sigma, x,u, "initial params") - - - - gr = jax.value_and_grad(logdensity_fn) - kernel = build_kernel(grad_logp=gr, +gr = jax.value_and_grad(logdensity_fn) +kernel = build_kernel(grad_logp=gr, integrator=minimal_norm, transform=lambda x:x) - - keys = jnp.array([jax.random.PRNGKey(0)]*num_steps) - - state, info = jax.lax.scan( - lambda s, k: (kernel(k, s, L, eps)), MCLMCState(x,u,-l,-g), keys - ) - - # state, info = kernel(jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps) - # xx,uu,ll,gg = state - X = info.transformed_x - - # sample_full(num_steps, x, u, l, g, random_key, L, eps, sigma) - ESS = ess_corr(X) - Lfactor = 0.4 - Lnew = Lfactor * eps / ESS # = 0.4 * correlation length - print(ESS, "ess", X, Lfactor, eps) - return Lnew, state - -def tune(num_steps: int, params: Parameters, rng_key: PRNGKey) -> Parameters: - - - - x, u, l, g, key, L, eps, sigma, steps1, steps2 = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), jax.random.PRNGKey(0), 1.4142135, 0.56568545, jnp.array([1., 1.]), 10, 10 - - L, eps, sigma, x, u, l, g, key = tune12(x, u, l, g, key, L, eps, sigma, steps1, steps2) - print("L, eps post tune12", L, eps) - - - - - - steps3 = int(num_steps * 0.1) - L, state = tune3(x, u, l, g, key, L, eps, sigma, steps3) - print("L post 
tune3", L) - return L, eps, state -L, eps, state = (tune(num_steps=100, params=initial_params, rng_key=tune_key)) +L, eps, state = (tune(kernel, num_steps=100, rng_key=tune_key)) print("L, eps post tuning", L, eps) raise Exception mclmc = blackjax.mcmc.mclmc.mclmc( From e75274afd90c26525e726d743394eeacb3b0b283 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Nov 2023 12:43:23 +0100 Subject: [PATCH 25/78] tuning works in this commit --- blackjax/mcmc/integrators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 2125aaf69..482e037e6 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -254,10 +254,10 @@ def step(state: IntegratorState, step_size): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" # V T V T V - jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) + # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g), state.momentum, state.logdensity_grad) - jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) + # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u), state.position, uu) uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g), uu, gg) From 8a89f13b5e99570fcb31107b360c69d8e53fecc0 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Nov 2023 12:50:28 +0100 Subject: [PATCH 26/78] clean up 1 --- blackjax/adaptation/step_size.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2edfa6184..5d4e3ad27 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -516,18 +516,15 @@ def tune3(kernel, x, u, l, g, rng_key, L, eps, sigma, num_steps): def tune(kernel, num_steps: int, rng_key: PRNGKey) -> Parameters: + num_tune_step_ratio_1 = 0.1 + num_tune_step_ratio_2 = 0.1 - x, u, l, g, key, L, eps, sigma, steps1, steps2 = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), jax.random.PRNGKey(0), 1.4142135, 0.56568545, jnp.array([1., 1.]), 10, 10 + x, u, l, g, L, eps, sigma = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), 1.4142135, 0.56568545, jnp.array([1., 1.]) - L, eps, sigma, x, u, l, g, key = tune12(kernel, x, u, l, g, key, L, eps, sigma, steps1, steps2) + L, eps, sigma, x, u, l, g, key = tune12(kernel, x, u, l, g, rng_key, L, eps, sigma, int(num_steps * num_tune_step_ratio_1), int(num_steps * num_tune_step_ratio_1)) print("L, eps post tune12", L, eps) - - - - - steps3 = int(num_steps * 0.1) - L, state = tune3(kernel, x, u, l, g, key, L, eps, sigma, steps3) + L, state = tune3(kernel, x, u, l, g, key, L, eps, sigma, int(num_steps * num_tune_step_ratio_2)) print("L post tune3", L) return L, eps, state \ No newline at end of file From 49b3bec2a49d201e9114d618503d933614c7d5ab Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Nov 2023 12:52:15 +0100 Subject: [PATCH 27/78] remove sigma from tuning --- blackjax/adaptation/step_size.py | 40 ++++++++------------------------ 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 5d4e3ad27..dcd1639ec 100644 --- a/blackjax/adaptation/step_size.py +++ 
b/blackjax/adaptation/step_size.py @@ -395,7 +395,7 @@ def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): return nonans, *jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (xx, uu, ll, gg, eps_max, dK), (x, u, l, g, eps * 0.8, 0.)) -def dynamics_adaptive(dynamics, state, L, sigma): +def dynamics_adaptive(dynamics, state, L): """One step of the dynamics with the adaptive stepsize""" x, u, l, g, E, Feps, Weps, eps_max, key = state @@ -403,19 +403,6 @@ def dynamics_adaptive(dynamics, state, L, sigma): eps = jnp.power(Feps/Weps, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - # grad_logp = jax.value_and_grad(logdensity_fn) - - - # print(L_given, eps, "\n\n\n") - # m = build_kernel(grad_logp, minimal_norm, lambda x:x, L_given, eps, sigma) - # dynamics = lambda x,u,l,g, key : m(jax.random.PRNGKey(0), MCLMCState(x,u,l,g)) - - # dynamics - - # xx, uu, ll, gg, kinetic_change, key = dynamics(x, u, g, key, L, eps, sigma) - # jax.debug.print("🤯 {x} x 🤯", x=(x,u,l,g, E, Feps, Weps, eps_max)) - jax.debug.print("🤯 {x} L eps 🤯", x=(L, eps, sigma)) - jax.debug.print("🤯 {x} x u 🤯", x=(x,u, g)) state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, -l, -g), L=L, step_size=eps) xx, uu, ll, gg = state @@ -445,21 +432,18 @@ def dynamics_adaptive(dynamics, state, L, sigma): return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success -def tune12(kernel,x, u, l, g, random_key, L_given, eps, sigma_given, num_steps1, num_steps2): +def tune12(kernel,x, u, l, g, random_key, L_given, eps, num_steps1, num_steps2): """cheap hyperparameter tuning""" # mclmc = blackjax.mclmc( # logdensity_fn=logdensity_fn, transform=lambda x: x, L=params.L, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix # ) - - - sigma = sigma_given - + def step(state, outer_weight): """one adaptive step of the dynamics""" # x,u,l,g = state # E, Feps, Weps, eps_max = 1.0,1.0,1.0,1.0 - x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(kernel, state[0], L, sigma) + x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(kernel, state[0], L) W, F1, F2 = state[1] w = outer_weight * eps zero_prevention = 1-outer_weight @@ -487,11 +471,10 @@ def step(state, outer_weight): L = jnp.sqrt(sigma2 * x.shape[0]) xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state - return L, eps[-1], sigma, xx, uu, ll, gg, key #return the tuned hyperparameters and the final state + return L, eps[-1], xx, uu, ll, gg, key #return the tuned hyperparameters and the final state -def tune3(kernel, x, u, l, g, rng_key, L, eps, sigma, num_steps): +def tune3(kernel, x, u, l, g, rng_key, L, eps, num_steps): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - print(L, eps, sigma, x,u, "initial params") @@ -503,14 +486,11 @@ def tune3(kernel, x, u, l, g, rng_key, L, eps, sigma, num_steps): lambda s, k: (kernel(k, s, L, eps)), MCLMCState(x,u,-l,-g), keys ) - # state, info = kernel(jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps) - # xx,uu,ll,gg = state X = info.transformed_x - # sample_full(num_steps, x, u, l, g, random_key, L, eps, sigma) ESS = ess_corr(X) Lfactor = 0.4 - Lnew = Lfactor * eps / ESS # = 0.4 * correlation length + Lnew = Lfactor * eps / ESS print(ESS, "ess", X, Lfactor, eps) return Lnew, state @@ 
-520,11 +500,11 @@ def tune(kernel, num_steps: int, rng_key: PRNGKey) -> Parameters: num_tune_step_ratio_2 = 0.1 - x, u, l, g, L, eps, sigma = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), 1.4142135, 0.56568545, jnp.array([1., 1.]) + x, u, l, g, L, eps = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), 1.4142135, 0.56568545 - L, eps, sigma, x, u, l, g, key = tune12(kernel, x, u, l, g, rng_key, L, eps, sigma, int(num_steps * num_tune_step_ratio_1), int(num_steps * num_tune_step_ratio_1)) + L, eps, x, u, l, g, key = tune12(kernel, x, u, l, g, rng_key, L, eps, int(num_steps * num_tune_step_ratio_1), int(num_steps * num_tune_step_ratio_1)) print("L, eps post tune12", L, eps) - L, state = tune3(kernel, x, u, l, g, key, L, eps, sigma, int(num_steps * num_tune_step_ratio_2)) + L, state = tune3(kernel, x, u, l, g, key, L, eps, int(num_steps * num_tune_step_ratio_2)) print("L post tune3", L) return L, eps, state \ No newline at end of file From 81999f9e20d474295929de3088a676052e08573c Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Nov 2023 14:43:58 +0100 Subject: [PATCH 28/78] wip --- blackjax/adaptation/step_size.py | 202 ++++++++++++++++++++++++------- 1 file changed, 158 insertions(+), 44 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index dcd1639ec..a96a0be7e 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Step size adaptation""" +import math from typing import Callable, NamedTuple import jax @@ -21,7 +22,7 @@ from blackjax.mcmc.hmc import HMCState from blackjax.mcmc.mclmc import MCLMCState from blackjax.optimizers.dual_averaging import dual_averaging -from blackjax.types import PRNGKey +from blackjax.types import Array, ArrayLikeTree, PRNGKey __all__ = [ "DualAveragingAdaptationState", @@ -271,12 +272,146 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: #### mclmc -class Parameters(NamedTuple): +class MCLMCAdaptationState(NamedTuple): """Tunable parameters""" L: float step_size: float +def effective_sample_size( + input_array: ArrayLikeTree, chain_axis: int = 0, sample_axis: int = 1 +) -> Array: + """Compute estimate of the effective sample size (ess). + + Parameters + ---------- + input_array: + An array representing multiple chains of MCMC samples. The array must + contains a chain dimension and a sample dimension. + chain_axis + The axis indicating the multiple chains. Default to 0. + sample_axis + The axis indicating a single chain of MCMC samples. Default to 1. + + Returns + ------- + NDArray of the resulting statistics (ess), with the chain and sample dimensions squeezed. + + Notes + ----- + The basic ess (:math:`N_{\\mathit{eff}}`) diagnostic is computed by: + + .. math:: \\hat{N}_{\\mathit{eff}} = \\frac{MN}{\\hat{\\tau}} + + .. math:: \\hat{\\tau} = -1 + 2 \\sum_{t'=0}^K \\hat{P}_{t'} + + where :math:`M` is the number of chains, :math:`N` the number of draws, + :math:`\\hat{\\rho}_t` is the estimated _autocorrelation at lag :math:`t`, and + :math:`K` is the last integer for which :math:`\\hat{P}_{K} = \\hat{\\rho}_{2K} + + \\hat{\\rho}_{2K+1}` is still positive :cite:p:`stan_ess,gelman1995bayesian`. + + The current implementation is similar to Stan, which uses Geyer's initial monotone sequence + criterion :cite:p:`geyer1992practical,geyer2011introduction`. 
+ + """ + input_shape = input_array.shape + sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis + num_chains = input_shape[chain_axis] + num_samples = input_shape[sample_axis] + assert ( + num_chains > 1 + ), "effective_sample_size as implemented only works for two or more chains." + + mean_across_chain = input_array.mean(axis=sample_axis, keepdims=True) + print("mean 1", mean_across_chain) + # Compute autocovariance estimates for every lag for the input array using FFT. + centered_array = input_array - mean_across_chain + m = next_fast_len(2 * num_samples) + ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=sample_axis) + ifft_ary *= jnp.conjugate(ifft_ary) + autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=sample_axis) + autocov_value = ( + jnp.take(autocov_value, jnp.arange(num_samples), axis=sample_axis) / num_samples + ) + mean_autocov_var = autocov_value.mean(chain_axis, keepdims=True) + mean_var0 = ( + jnp.take(mean_autocov_var, jnp.array([0]), axis=sample_axis) + * num_samples + / (num_samples - 1.0) + ) + weighted_var = mean_var0 * (num_samples - 1.0) / num_samples + weighted_var = jax.lax.cond( + num_chains > 1, + lambda _: weighted_var + + mean_across_chain.var(axis=chain_axis, ddof=1, keepdims=True), + lambda _: weighted_var, + operand=None, + ) + + # Geyer's initial positive sequence + num_samples_even = num_samples - num_samples % 2 + mean_autocov_var_tp1 = jnp.take( + mean_autocov_var, jnp.arange(1, num_samples_even), axis=sample_axis + ) + rho_hat = jnp.concatenate( + [ + jnp.ones_like(mean_var0), + 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, + ], + axis=sample_axis, + ) + + rho_hat = jnp.moveaxis(rho_hat, sample_axis, 0) + rho_hat_even = rho_hat[0::2] + rho_hat_odd = rho_hat[1::2] + + mask0 = (rho_hat_even + rho_hat_odd) > 0.0 + carry_cond = jnp.ones_like(mask0[0]) + max_t = jnp.zeros_like(mask0[0], dtype=int) + + def positive_sequence_body_fn(state, mask_t): + t, carry_cond, max_t = state + next_mask = carry_cond & mask_t + next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) + return (t + 1, next_mask, next_max_t), next_mask + + (*_, max_t_next), mask = jax.lax.scan( + positive_sequence_body_fn, (0, carry_cond, max_t), mask0 + ) + indices = jnp.indices(max_t_next.shape) + indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) + rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) + # improve estimation + mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) + rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) + + # Geyer's initial monotone sequence + def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): + update_mask = rho_hat_sum_t > rho_hat_sum_tm1 + next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) + return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) + + rho_hat_sum = rho_hat_even + rho_hat_odd + _, (update_mask, update_value) = jax.lax.scan( + monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum + ) + + rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) + rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) + + # compute effective sample size + ess_raw = num_chains * num_samples + tau_hat = ( + -1.0 + + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) + - rho_hat_even_final[indices] + ) + + tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) + ess = ess_raw / tau_hat + print("tau hat", ess) + + return ess.squeeze() def 
ess_corr(x): """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html @@ -292,6 +427,7 @@ def ess_corr(x): num_samples = input_array.shape[1] mean_across_chain = input_array.mean(axis=1, keepdims=True) + print("mean 2", mean_across_chain) # Compute autocovariance estimates for every lag for the input array using FFT. centered_array = input_array - mean_across_chain m = next_fast_len(2 * num_samples) @@ -379,22 +515,16 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): ### my part (combine all dimensions): ### neff = ess.squeeze() / num_samples + print("tau hat", ess, num_samples, neff) return 1.0 / jnp.average(1 / neff) -# ? -# tuning() -num_steps = 100 -initial_params = Parameters(1.9913111, 0.6458658) - def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" nonans = jnp.all(jnp.isfinite(xx)) - return nonans, *jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (xx, uu, ll, gg, eps_max, dK), (x, u, l, g, eps * 0.8, 0.)) - def dynamics_adaptive(dynamics, state, L): """One step of the dynamics with the adaptive stepsize""" @@ -403,11 +533,10 @@ def dynamics_adaptive(dynamics, state, L): eps = jnp.power(Feps/Weps, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, -l, -g), L=L, step_size=eps) + state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps) xx, uu, ll, gg = state - ll, gg = -ll, -gg - # jax.debug.print("🤯 {x} xx uu 🤯", x=(xx,uu)) + # ll, gg = -ll, -gg kinetic_change = info.kinetic_change varEwanted = 5e-4 @@ -417,12 +546,10 @@ def dynamics_adaptive(dynamics, state, L): # step updating - # jax.debug.print("🤯 {x} L eps 🤯", x=(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change)) success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change) DE = info.dE # energy difference - # jax.debug.print("🤯 {x} DE 🤯", x=(DE, kinetic_change)) EE = E + DE # energy # Warning: var = 0 if there were nans, but we will give it a very small weight xi = ((DE ** 2) / (xx.shape[0] * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi @@ -432,17 +559,11 @@ def dynamics_adaptive(dynamics, state, L): return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success -def tune12(kernel,x, u, l, g, random_key, L_given, eps, num_steps1, num_steps2): +def tune12(kernel,x, u, l, g, random_key, L, eps, num_steps1, num_steps2): """cheap hyperparameter tuning""" - - # mclmc = blackjax.mclmc( - # logdensity_fn=logdensity_fn, transform=lambda x: x, L=params.L, step_size=params.step_size, inverse_mass_matrix=params.inverse_mass_matrix - # ) def step(state, outer_weight): """one adaptive step of the dynamics""" - # x,u,l,g = state - # E, Feps, Weps, eps_max = 1.0,1.0,1.0,1.0 x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(kernel, state[0], L) W, F1, F2 = state[1] w = outer_weight * eps @@ -453,8 +574,6 @@ def step(state, outer_weight): return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps - L = L_given - # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) @@ -470,41 +589,36 @@ def 
step(state, outer_weight): L = jnp.sqrt(sigma2 * x.shape[0]) - xx, uu, ll, gg, key = state[0][0], state[0][1], state[0][2], state[0][3], state[0][-1] # the final state - return L, eps[-1], xx, uu, ll, gg, key #return the tuned hyperparameters and the final state + xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state + return L, eps[-1], MCLMCState(xx, uu, ll, gg) #return the tuned hyperparameters and the final state -def tune3(kernel, x, u, l, g, rng_key, L, eps, num_steps): +def tune3(kernel, state, rng_key, L, eps, num_steps): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - - - - keys = jnp.array([jax.random.PRNGKey(0)]*num_steps) - state, info = jax.lax.scan( - lambda s, k: (kernel(k, s, L, eps)), MCLMCState(x,u,-l,-g), keys + lambda s, k: (kernel(k, s, L, eps)), state, jax.random.split(rng_key, num_steps) ) - - X = info.transformed_x - ESS = ess_corr(X) Lfactor = 0.4 - Lnew = Lfactor * eps / ESS - print(ESS, "ess", X, Lfactor, eps) + ESS2 = effective_sample_size(info.transformed_x) + neff = ESS2.squeeze() / info.transformed_x.shape[0] + print("neff", neff, info.transformed_x.shape[0]) + ESS_alt = 1.0 / jnp.average(1 / neff) + print(ess_corr(info.transformed_x), ESS_alt, "\n\nESSse\n\n") + Lnew = Lfactor * eps / ess_corr(info.transformed_x) return Lnew, state -def tune(kernel, num_steps: int, rng_key: PRNGKey) -> Parameters: +def tune(kernel, num_steps: int, rng_key: PRNGKey, params : MCLMCAdaptationState) -> tuple[MCLMCAdaptationState, MCLMCState]: num_tune_step_ratio_1 = 0.1 num_tune_step_ratio_2 = 0.1 + x, u, l, g = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), -0.010000001, -jnp.array([0.1, 0.1]) - x, u, l, g, L, eps = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), 0.010000001, jnp.array([0.1, 0.1]), 1.4142135, 0.56568545 + tune1_key, tune2_key = jax.random.split(rng_key) - L, eps, x, u, l, g, key = tune12(kernel, x, u, l, g, rng_key, L, eps, int(num_steps * num_tune_step_ratio_1), int(num_steps * num_tune_step_ratio_1)) - print("L, eps post tune12", L, eps) + L, eps, state = tune12(kernel, x, u, l, g, tune1_key, params.L, params.step_size, int(num_steps * num_tune_step_ratio_1), int(num_steps * num_tune_step_ratio_1)) - L, state = tune3(kernel, x, u, l, g, key, L, eps, int(num_steps * num_tune_step_ratio_2)) - print("L post tune3", L) - return L, eps, state \ No newline at end of file + L, state = tune3(kernel, state, tune2_key, L, eps, int(num_steps * num_tune_step_ratio_2)) + return MCLMCAdaptationState(L, eps), state \ No newline at end of file From 8ab01f244a1afd2a362e2144ff2b3e6f75209d72 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 Nov 2023 16:23:30 +0000 Subject: [PATCH 29/78] fix linting --- blackjax/__init__.py | 2 +- blackjax/adaptation/step_size.py | 199 +++++++++++++++++++------------ blackjax/mcmc/integrators.py | 32 +++-- blackjax/mcmc/mclmc.py | 33 ++--- 4 files changed, 160 insertions(+), 106 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 691f1c92f..bd8c452e2 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -10,8 +10,8 @@ from .mcmc.ghmc import ghmc from .mcmc.hmc import dynamic_hmc, hmc from .mcmc.mala import mala -from .mcmc.mclmc import mclmc from .mcmc.marginal_latent_gaussian import mgrad_gaussian +from .mcmc.mclmc import mclmc from .mcmc.nuts import nuts from .mcmc.periodic_orbital import orbital_hmc from .mcmc.random_walk import additive_step_random_walk, irmh, rmh diff --git a/blackjax/adaptation/step_size.py 
b/blackjax/adaptation/step_size.py index a96a0be7e..09e74a4ef 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Step size adaptation""" -import math from typing import Callable, NamedTuple import jax @@ -262,22 +261,13 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: return rss_state.step_size - - - - - - - - -#### mclmc - class MCLMCAdaptationState(NamedTuple): - """Tunable parameters""" + """Tunable parameters for MCLMC""" L: float step_size: float + def effective_sample_size( input_array: ArrayLikeTree, chain_axis: int = 0, sample_axis: int = 1 ) -> Array: @@ -413,6 +403,7 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): return ess.squeeze() + def ess_corr(x): """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html shape(x) = (num_samples, d)""" @@ -513,93 +504,125 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) ess = ess_raw / tau_hat - ### my part (combine all dimensions): ### neff = ess.squeeze() / num_samples print("tau hat", ess, num_samples, neff) return 1.0 / jnp.average(1 / neff) def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): - """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" - - nonans = jnp.all(jnp.isfinite(xx)) - return nonans, *jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (xx, uu, ll, gg, eps_max, dK), (x, u, l, g, eps * 0.8, 0.)) - -def dynamics_adaptive(dynamics, state, L): - """One step of the dynamics with the adaptive stepsize""" - - x, u, l, g, E, Feps, Weps, eps_max, key = state + """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" - eps = jnp.power(Feps/Weps, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. - eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - - state, info = dynamics(jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps) - - xx, uu, ll, gg = state - # ll, gg = -ll, -gg - kinetic_change = info.kinetic_change + nonans = jnp.all(jnp.isfinite(xx)) + return nonans, *jax.tree_util.tree_map( + lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (xx, uu, ll, gg, eps_max, dK), + (x, u, l, g, eps * 0.8, 0.0), + ) - varEwanted = 5e-4 - sigma_xi= 1.5 - neff = 150 # effective number of steps used to determine the stepsize in the adaptive step - gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step +def dynamics_adaptive(dynamics, state, L): + """One step of the dynamics with the adaptive stepsize""" - # step updating - success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change) + x, u, l, g, E, Feps, Weps, eps_max, key = state + eps = jnp.power( + Feps / Weps, -1.0 / 6.0 + ) # We use the Var[E] = O(eps^6) relation here. 
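# Editor's note (a sketch, not part of the patch): xi, defined below, measures
# the squared energy error per dimension relative to the target varEwanted, and
# for this integrator Var[E] = O(eps^6), so roughly xi ~= c * eps**6 for some
# constant c. Feps / Weps is a weighted running estimate of that constant, and
# solving c * eps_new**6 = 1 (energy error exactly on target) gives the update
# computed just above:
#
#     eps_new = (Feps / Weps) ** (-1.0 / 6.0)
#
# For example, an estimate Feps / Weps = 64.0 would give eps_new = 64.0 ** (-1/6) = 0.5.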
+ eps = (eps < eps_max) * eps + ( + eps > eps_max + ) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - DE = info.dE # energy difference - EE = E + DE # energy - # Warning: var = 0 if there were nans, but we will give it a very small weight - xi = ((DE ** 2) / (xx.shape[0] * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi - w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. - Feps = gamma * Feps + w * (xi/jnp.power(eps, 6.0)) # Kalman update the linear combinations - Weps = gamma * Weps + w + state, info = dynamics( + jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps + ) - return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success + xx, uu, ll, gg = state + # ll, gg = -ll, -gg + kinetic_change = info.kinetic_change -def tune12(kernel,x, u, l, g, random_key, L, eps, num_steps1, num_steps2): - """cheap hyperparameter tuning""" - - def step(state, outer_weight): - """one adaptive step of the dynamics""" - x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(kernel, state[0], L) - W, F1, F2 = state[1] - w = outer_weight * eps - zero_prevention = 1-outer_weight - F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter - F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter - W += w + varEwanted = 5e-4 + sigma_xi = 1.5 + neff = 150 # effective number of steps used to determine the stepsize in the adaptive step + gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step - return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps + # step updating + success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject( + x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change + ) - # we use the last num_steps2 to compute the diagonal preconditioner - outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + DE = info.dE # energy difference + EE = E + DE # energy + # Warning: var = 0 if there were nans, but we will give it a very small weight + xi = ( + (DE**2) / (xx.shape[0] * varEwanted) + ) + 1e-8 # 1e-8 is added to avoid divergences in log xi + w = jnp.exp( + -0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi)) + ) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. 
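# Editor's note (a sketch, not part of the patch): the weight w computed above
# is a Gaussian in log(xi), so a step whose energy error matches the target
# (xi == 1) gets full weight w == 1, while a step that is orders of magnitude
# off target is almost ignored by the running averages updated below. With
# sigma_xi = 1.5: xi = 1 gives w = 1, xi = e**9 gives w = exp(-0.5) ~= 0.61, and
# xi = e**27 gives w = exp(-4.5) ~= 0.011, so a single divergent proposal cannot
# drag Feps / Weps far from their current values.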
+ Feps = gamma * Feps + w * ( + xi / jnp.power(eps, 6.0) + ) # Kalman update the linear combinations + Weps = gamma * Weps + w + + return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success + + +def tune12(kernel, x, u, l, g, random_key, L, eps, num_steps1, num_steps2): + """cheap hyperparameter tuning""" + + def step(state, outer_weight): + """one adaptive step of the dynamics""" + x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive( + kernel, state[0], L + ) + W, F1, F2 = state[1] + w = outer_weight * eps + zero_prevention = 1 - outer_weight + F1 = (W * F1 + w * x) / ( + W + w + zero_prevention + ) # Update with a Kalman filter + F2 = (W * F2 + w * jnp.square(x)) / ( + W + w + zero_prevention + ) # Update with a Kalman filter + W += w + + return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps + + # we use the last num_steps2 to compute the diagonal preconditioner + outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + # initial state + state = ( + (x, u, l, g, 0.0, jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), + (0.0, jnp.zeros(len(x)), jnp.zeros(len(x))), + ) + # run the steps + state, eps = jax.lax.scan( + step, init=state, xs=outer_weights, length=num_steps1 + num_steps2 + ) + # determine L + if num_steps2 != 0.0: + F1, F2 = state[1][1], state[1][2] + variances = F2 - jnp.square(F1) + sigma2 = jnp.average(variances) - #initial state - state = ((x, u, l, g, 0., jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), (0., jnp.zeros(len(x)), jnp.zeros(len(x)))) - # run the steps - state, eps = jax.lax.scan(step, init=state, xs= outer_weights, length= num_steps1 + num_steps2) - # determine L - if num_steps2 != 0.: - F1, F2 = state[1][1], state[1][2] - variances = F2 - jnp.square(F1) - sigma2 = jnp.average(variances) + L = jnp.sqrt(sigma2 * x.shape[0]) - L = jnp.sqrt(sigma2 * x.shape[0]) + xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state + return ( + L, + eps[-1], + MCLMCState(xx, uu, ll, gg), + ) # return the tuned hyperparameters and the final state - xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state - return L, eps[-1], MCLMCState(xx, uu, ll, gg) #return the tuned hyperparameters and the final state def tune3(kernel, state, rng_key, L, eps, num_steps): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - state, info = jax.lax.scan( lambda s, k: (kernel(k, s, L, eps)), state, jax.random.split(rng_key, num_steps) ) - + Lfactor = 0.4 ESS2 = effective_sample_size(info.transformed_x) neff = ESS2.squeeze() / info.transformed_x.shape[0] @@ -609,16 +632,36 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): Lnew = Lfactor * eps / ess_corr(info.transformed_x) return Lnew, state -def tune(kernel, num_steps: int, rng_key: PRNGKey, params : MCLMCAdaptationState) -> tuple[MCLMCAdaptationState, MCLMCState]: +def tune( + kernel, num_steps: int, rng_key: PRNGKey, params: MCLMCAdaptationState +) -> tuple[MCLMCAdaptationState, MCLMCState]: num_tune_step_ratio_1 = 0.1 num_tune_step_ratio_2 = 0.1 - x, u, l, g = jnp.array([0.1, 0.1]), jnp.array([-0.6755803, 0.73728645]), -0.010000001, -jnp.array([0.1, 0.1]) + x, u, l, g = ( + jnp.array([0.1, 0.1]), + jnp.array([-0.6755803, 0.73728645]), + -0.010000001, + -jnp.array([0.1, 0.1]), + ) tune1_key, tune2_key = jax.random.split(rng_key) - L, eps, state = tune12(kernel, x, u, l, g, tune1_key, params.L, params.step_size, int(num_steps * num_tune_step_ratio_1), int(num_steps * num_tune_step_ratio_1)) + L, eps, 
state = tune12( + kernel, + x, + u, + l, + g, + tune1_key, + params.L, + params.step_size, + int(num_steps * num_tune_step_ratio_1), + int(num_steps * num_tune_step_ratio_1), + ) - L, state = tune3(kernel, state, tune2_key, L, eps, int(num_steps * num_tune_step_ratio_2)) - return MCLMCAdaptationState(L, eps), state \ No newline at end of file + L, state = tune3( + kernel, state, tune2_key, L, eps, int(num_steps * num_tune_step_ratio_2) + ) + return MCLMCAdaptationState(L, eps), state diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 482e037e6..0dde8520f 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -247,6 +247,7 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return one_step + def minimal_norm(T, V): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator @@ -255,14 +256,25 @@ def step(state: IntegratorState, step_size): # V T V T V # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) - uu, r1 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g), state.momentum, - state.logdensity_grad) + uu, r1 = jax.tree_util.tree_map( + lambda u, g: V(step_size * lambda_c, u, g), + state.momentum, + state.logdensity_grad, + ) # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) - xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u), state.position, uu) - uu, r2 = jax.tree_util.tree_map(lambda u, g : V(step_size * (1 - 2 * lambda_c), u, g), uu, gg) - xx, ll, gg = jax.tree_util.tree_map(lambda x, u : T(step_size, x, 0.5 * u), xx, uu) - uu, r3 = jax.tree_util.tree_map(lambda u, g : V(step_size * lambda_c, u, g), uu, gg) + xx, ll, gg = jax.tree_util.tree_map( + lambda x, u: T(step_size, x, 0.5 * u), state.position, uu + ) + uu, r2 = jax.tree_util.tree_map( + lambda u, g: V(step_size * (1 - 2 * lambda_c), u, g), uu, gg + ) + xx, ll, gg = jax.tree_util.tree_map( + lambda x, u: T(step_size, x, 0.5 * u), xx, uu + ) + uu, r3 = jax.tree_util.tree_map( + lambda u, g: V(step_size * lambda_c, u, g), uu, gg + ) # kinetic energy change kinetic_change = (r1 + r2 + r3) * (uu.shape[0] - 1) @@ -272,11 +284,9 @@ def step(state: IntegratorState, step_size): return step - - def update_position_mclmc(grad_logp): - """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) - """ + """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" + def update(step_size, x, u): xx = x + step_size * u ll, gg = grad_logp(xx) @@ -284,6 +294,7 @@ def update(step_size, x, u): return update + def update_momentum_mclmc(step_size, u, g): """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) similar to the implementation: https://github.com/gregversteeg/esh_dynamics @@ -299,4 +310,3 @@ def update_momentum_mclmc(step_size, u, g): uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r - diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 30cd0f1af..1370af84e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -21,20 +21,21 @@ from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLike, PRNGKey -__all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc", "Parameters"] +__all__ = ["MCLMCState", 
"MCLMCInfo", "init", "build_kernel", "mclmc"] MCLMCState = integrators.IntegratorState + class MCLMCInfo(NamedTuple): """Additional information on the MCLMC transition. - + transformed_x The value of the samples after a transformation (e.g. projection onto lower dim subspace) logdensity logdensity at given step dE energy difference - + """ transformed_x: Array @@ -46,7 +47,7 @@ class MCLMCInfo(NamedTuple): def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) # jax.debug.print("🤯 {x} initial momentum 🤯", x=random_unit_vector(rng_key, dim=x_initial.shape[0])) - + return MCLMCState( position=x_initial, momentum=random_unit_vector(rng_key, dim=x_initial.shape[0]), @@ -54,8 +55,8 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) + def build_kernel(grad_logp, integrator, transform): - """Build a HMC kernel. Parameters @@ -68,7 +69,7 @@ def build_kernel(grad_logp, integrator, transform): the momentum decoherence rate step_size step size of the integrator - + Returns ------- A kernel that takes a rng_key and a Pytree that contains the current state @@ -76,13 +77,17 @@ def build_kernel(grad_logp, integrator, transform): information about the transition. """ - step = integrator(T=integrators.update_position_mclmc(grad_logp), V=integrators.update_momentum_mclmc) + step = integrator( + T=integrators.update_position_mclmc(grad_logp), + V=integrators.update_momentum_mclmc, + ) - def kernel(rng_key: PRNGKey, state: MCLMCState, L : float, step_size : float) -> tuple[MCLMCState, MCLMCInfo]: + def kernel( + rng_key: PRNGKey, state: MCLMCState, L: float, step_size: float + ) -> tuple[MCLMCState, MCLMCInfo]: xx, uu, ll, gg, kinetic_change = step(state, step_size) # jax.debug.print("🤯 {x} new 🤯", x=(kinetic_change, ll, state.logdensity)) - - + dim = xx.shape[0] # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) @@ -92,13 +97,12 @@ def kernel(rng_key: PRNGKey, state: MCLMCState, L : float, step_size : float) - transformed_x=transform(xx), logdensity=ll, dE=kinetic_change - ll + state.logdensity, - kinetic_change=kinetic_change + kinetic_change=kinetic_change, ) return kernel - class mclmc: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. 
Since most users only need to specify the kernel @@ -165,7 +169,7 @@ def __new__( # type: ignore[misc] kernel = cls.build_kernel(grad_logp, integrator, transform) def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) + return kernel(rng_key, state, L, step_size) def init_fn(position: ArrayLike): return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) @@ -184,10 +188,7 @@ def random_unit_vector(rng_key, dim): return u - - def partially_refresh_momentum(u, rng_key, nu): """Adds a small noise to u and normalizes.""" z = nu * jax.random.normal(rng_key, shape=(u.shape[0],)) return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - From 6266bc4b615a882c7b9ef7b0393d72ded3606e39 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 Nov 2023 18:01:25 +0000 Subject: [PATCH 30/78] rename T and V --- blackjax/mcmc/integrators.py | 12 ++++++------ blackjax/mcmc/mclmc.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 0dde8520f..7ccc84f64 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -248,7 +248,7 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return one_step -def minimal_norm(T, V): +def minimal_norm(update_position, update_momentum): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator def step(state: IntegratorState, step_size): @@ -257,23 +257,23 @@ def step(state: IntegratorState, step_size): # V T V T V # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) uu, r1 = jax.tree_util.tree_map( - lambda u, g: V(step_size * lambda_c, u, g), + lambda u, g: update_momentum(step_size * lambda_c, u, g), state.momentum, state.logdensity_grad, ) # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) xx, ll, gg = jax.tree_util.tree_map( - lambda x, u: T(step_size, x, 0.5 * u), state.position, uu + lambda x, u: update_position(step_size, x, 0.5 * u), state.position, uu ) uu, r2 = jax.tree_util.tree_map( - lambda u, g: V(step_size * (1 - 2 * lambda_c), u, g), uu, gg + lambda u, g: update_momentum(step_size * (1 - 2 * lambda_c), u, g), uu, gg ) xx, ll, gg = jax.tree_util.tree_map( - lambda x, u: T(step_size, x, 0.5 * u), xx, uu + lambda x, u: update_position(step_size, x, 0.5 * u), xx, uu ) uu, r3 = jax.tree_util.tree_map( - lambda u, g: V(step_size * lambda_c, u, g), uu, gg + lambda u, g: update_momentum(step_size * lambda_c, u, g), uu, gg ) # kinetic energy change diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 1370af84e..e0c2937e5 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -78,8 +78,8 @@ def build_kernel(grad_logp, integrator, transform): """ step = integrator( - T=integrators.update_position_mclmc(grad_logp), - V=integrators.update_momentum_mclmc, + update_position=integrators.update_position_mclmc(grad_logp), + update_momentum=integrators.update_momentum_mclmc, ) def kernel( From ca984e757e0ab44b1597f7bbf6584a87b8128dbd Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 Nov 2023 18:21:04 +0000 Subject: [PATCH 31/78] uniformity wip --- blackjax/mcmc/integrators.py | 45 ++++++++++++++++++++---------------- blackjax/mcmc/mclmc.py | 4 ++-- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 7ccc84f64..e627ec4f1 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -12,6 +12,7 @@ # See the License for the 
specific language governing permissions and # limitations under the License. """Symplectic, time-reversible, integrators for Hamiltonian trajectories.""" +import functools from typing import Callable, NamedTuple import jax @@ -248,7 +249,7 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return one_step -def minimal_norm(update_position, update_momentum): +def minimal_norm(O1, O2): lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator def step(state: IntegratorState, step_size): @@ -257,23 +258,24 @@ def step(state: IntegratorState, step_size): # V T V T V # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) uu, r1 = jax.tree_util.tree_map( - lambda u, g: update_momentum(step_size * lambda_c, u, g), + lambda x, u, g: O1(step_size * lambda_c)(x, u, g), + state.position, state.momentum, state.logdensity_grad, ) # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) - xx, ll, gg = jax.tree_util.tree_map( - lambda x, u: update_position(step_size, x, 0.5 * u), state.position, uu + xx, _, ll, gg = jax.tree_util.tree_map( + lambda x, u: O2(step_size, x, u), state.position, uu*0.5 ) uu, r2 = jax.tree_util.tree_map( - lambda u, g: update_momentum(step_size * (1 - 2 * lambda_c), u, g), uu, gg + lambda x, u, g: O1(step_size * (1 - 2 * lambda_c))(x, u, g), xx, uu, gg ) - xx, ll, gg = jax.tree_util.tree_map( - lambda x, u: update_position(step_size, x, 0.5 * u), xx, uu + xx, _, ll, gg = jax.tree_util.tree_map( + lambda x, u: O2(step_size, x, u), xx, uu*0.5 ) uu, r3 = jax.tree_util.tree_map( - lambda u, g: update_momentum(step_size * lambda_c, u, g), uu, gg + lambda x, u, g: O1(step_size * lambda_c)(x, u, g), xx, uu, gg ) # kinetic energy change @@ -290,23 +292,26 @@ def update_position_mclmc(grad_logp): def update(step_size, x, u): xx = x + step_size * u ll, gg = grad_logp(xx) - return xx, ll, gg + return xx, u, ll, gg return update -def update_momentum_mclmc(step_size, u, g): +def update_momentum_mclmc(step_size): """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) similar to the implementation: https://github.com/gregversteeg/esh_dynamics There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
""" - g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) - e = g / g_norm - ue = jax.numpy.dot(u, e) - # jax.debug.print("🤯 {x} inside momentum update 🤯", x=(ue)) - dim = u.shape[0] - delta = step_size * g_norm / (dim - 1) - zeta = jax.numpy.exp(-delta) - uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u - delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) - return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r + + def update(x, u, g): + g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) + e = g / g_norm + ue = jax.numpy.dot(u, e) + # jax.debug.print("🤯 {x} inside momentum update 🤯", x=(ue)) + dim = u.shape[0] + delta = step_size * g_norm / (dim - 1) + zeta = jax.numpy.exp(-delta) + uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u + delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) + return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r + return update \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e0c2937e5..0a92f0577 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -78,8 +78,8 @@ def build_kernel(grad_logp, integrator, transform): """ step = integrator( - update_position=integrators.update_position_mclmc(grad_logp), - update_momentum=integrators.update_momentum_mclmc, + O1=integrators.update_momentum_mclmc, + O2=integrators.update_position_mclmc(grad_logp), ) def kernel( From 59ffb213db296be14859c846ca4622575097ac23 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 Nov 2023 20:08:31 +0000 Subject: [PATCH 32/78] make uniform implementation of integrators --- blackjax/adaptation/step_size.py | 141 ++----------------------------- blackjax/diagnostics.py | 2 + blackjax/mcmc/integrators.py | 139 ++++++++++++++++++++++-------- blackjax/mcmc/mclmc.py | 7 +- 4 files changed, 117 insertions(+), 172 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 09e74a4ef..c67c500f3 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from scipy.fft import next_fast_len +from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.hmc import HMCState from blackjax.mcmc.mclmc import MCLMCState @@ -268,141 +269,6 @@ class MCLMCAdaptationState(NamedTuple): step_size: float -def effective_sample_size( - input_array: ArrayLikeTree, chain_axis: int = 0, sample_axis: int = 1 -) -> Array: - """Compute estimate of the effective sample size (ess). - - Parameters - ---------- - input_array: - An array representing multiple chains of MCMC samples. The array must - contains a chain dimension and a sample dimension. - chain_axis - The axis indicating the multiple chains. Default to 0. - sample_axis - The axis indicating a single chain of MCMC samples. Default to 1. - - Returns - ------- - NDArray of the resulting statistics (ess), with the chain and sample dimensions squeezed. - - Notes - ----- - The basic ess (:math:`N_{\\mathit{eff}}`) diagnostic is computed by: - - .. math:: \\hat{N}_{\\mathit{eff}} = \\frac{MN}{\\hat{\\tau}} - - .. 
math:: \\hat{\\tau} = -1 + 2 \\sum_{t'=0}^K \\hat{P}_{t'} - - where :math:`M` is the number of chains, :math:`N` the number of draws, - :math:`\\hat{\\rho}_t` is the estimated _autocorrelation at lag :math:`t`, and - :math:`K` is the last integer for which :math:`\\hat{P}_{K} = \\hat{\\rho}_{2K} + - \\hat{\\rho}_{2K+1}` is still positive :cite:p:`stan_ess,gelman1995bayesian`. - - The current implementation is similar to Stan, which uses Geyer's initial monotone sequence - criterion :cite:p:`geyer1992practical,geyer2011introduction`. - - """ - input_shape = input_array.shape - sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis - num_chains = input_shape[chain_axis] - num_samples = input_shape[sample_axis] - assert ( - num_chains > 1 - ), "effective_sample_size as implemented only works for two or more chains." - - mean_across_chain = input_array.mean(axis=sample_axis, keepdims=True) - print("mean 1", mean_across_chain) - # Compute autocovariance estimates for every lag for the input array using FFT. - centered_array = input_array - mean_across_chain - m = next_fast_len(2 * num_samples) - ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=sample_axis) - ifft_ary *= jnp.conjugate(ifft_ary) - autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=sample_axis) - autocov_value = ( - jnp.take(autocov_value, jnp.arange(num_samples), axis=sample_axis) / num_samples - ) - mean_autocov_var = autocov_value.mean(chain_axis, keepdims=True) - mean_var0 = ( - jnp.take(mean_autocov_var, jnp.array([0]), axis=sample_axis) - * num_samples - / (num_samples - 1.0) - ) - weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - weighted_var = jax.lax.cond( - num_chains > 1, - lambda _: weighted_var - + mean_across_chain.var(axis=chain_axis, ddof=1, keepdims=True), - lambda _: weighted_var, - operand=None, - ) - - # Geyer's initial positive sequence - num_samples_even = num_samples - num_samples % 2 - mean_autocov_var_tp1 = jnp.take( - mean_autocov_var, jnp.arange(1, num_samples_even), axis=sample_axis - ) - rho_hat = jnp.concatenate( - [ - jnp.ones_like(mean_var0), - 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, - ], - axis=sample_axis, - ) - - rho_hat = jnp.moveaxis(rho_hat, sample_axis, 0) - rho_hat_even = rho_hat[0::2] - rho_hat_odd = rho_hat[1::2] - - mask0 = (rho_hat_even + rho_hat_odd) > 0.0 - carry_cond = jnp.ones_like(mask0[0]) - max_t = jnp.zeros_like(mask0[0], dtype=int) - - def positive_sequence_body_fn(state, mask_t): - t, carry_cond, max_t = state - next_mask = carry_cond & mask_t - next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) - return (t + 1, next_mask, next_max_t), next_mask - - (*_, max_t_next), mask = jax.lax.scan( - positive_sequence_body_fn, (0, carry_cond, max_t), mask0 - ) - indices = jnp.indices(max_t_next.shape) - indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) - rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) - # improve estimation - mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) - rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) - - # Geyer's initial monotone sequence - def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): - update_mask = rho_hat_sum_t > rho_hat_sum_tm1 - next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) - return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) - - rho_hat_sum = rho_hat_even + rho_hat_odd - _, (update_mask, update_value) = jax.lax.scan( - 
monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum - ) - - rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) - rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) - - # compute effective sample size - ess_raw = num_chains * num_samples - tau_hat = ( - -1.0 - + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) - - rho_hat_even_final[indices] - ) - - tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) - ess = ess_raw / tau_hat - print("tau hat", ess) - - return ess.squeeze() - def ess_corr(x): """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html @@ -414,6 +280,9 @@ def ess_corr(x): ] ) + print(input_array.shape,"input shape 2") + + num_chains = 1 # input_array.shape[0] num_samples = input_array.shape[1] @@ -435,6 +304,8 @@ def ess_corr(x): / (num_samples - 1.0) ) weighted_var = mean_var0 * (num_samples - 1.0) / num_samples + jax.debug.print("🤯 {x} weighted_var 2 🤯", x=weighted_var) + weighted_var = jax.lax.cond( num_chains > 1, lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index da861d9b1..2f9bad5ba 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -112,6 +112,7 @@ def effective_sample_size( """ input_shape = input_array.shape + print(input_shape, "input shape") sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis num_chains = input_shape[chain_axis] num_samples = input_shape[sample_axis] @@ -143,6 +144,7 @@ def effective_sample_size( lambda _: weighted_var, operand=None, ) + jax.debug.print("🤯 {x} weighted_var 🤯", x=weighted_var) # Geyer's initial positive sequence num_samples_even = num_samples - num_samples % 2 diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index e627ec4f1..c06dcf94a 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Symplectic, time-reversible, integrators for Hamiltonian trajectories.""" import functools +import itertools from typing import Callable, NamedTuple import jax @@ -249,39 +250,39 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return one_step -def minimal_norm(O1, O2): - lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator +# def palindromic_sequence(s): + +# return lambda O1, O2 : list(itertools.zip_longest(itertools.cycle([O1, O2]), s)) + +# print(palindromic_sequence([1,2])('f','g')) + +# minimal_norm_sequence = palindromic_sequence([0.1931833275037836, 1., 1.- 2.*0.1931833275037836]) + +minimal_norm_sequence = lambda O1, O2 : [ + (O1, 0.1931833275037836), + (O2, 1.), + (O1, 1.- 2.*0.1931833275037836), + (O2, 1.), + (O1, 0.1931833275037836), + ] + +def make_integrator(O1, O2, order): + + sequence = order(O1,O2) + def step(state: IntegratorState, step_size): """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" - # V T V T V - # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) - uu, r1 = jax.tree_util.tree_map( - lambda x, u, g: O1(step_size * lambda_c)(x, u, g), - state.position, - state.momentum, - state.logdensity_grad, - ) - # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) - - xx, _, ll, gg = jax.tree_util.tree_map( - lambda x, u: O2(step_size, x, u), state.position, uu*0.5 - ) - uu, r2 = jax.tree_util.tree_map( - lambda x, u, g: O1(step_size * (1 - 2 * lambda_c))(x, u, g), xx, uu, gg - ) - xx, _, ll, gg = jax.tree_util.tree_map( - lambda x, u: O2(step_size, x, u), xx, uu*0.5 - ) - uu, r3 = jax.tree_util.tree_map( - lambda x, u, g: O1(step_size * lambda_c)(x, u, g), xx, uu, gg - ) + xx, uu, ll, gg = state + total_r = 0 - # kinetic energy change - kinetic_change = (r1 + r2 + r3) * (uu.shape[0] - 1) + for O,factor in sequence: + xx, uu, ll, gg, r = jax.tree_util.tree_map(O(step_size*factor), xx, uu, ll, gg) + total_r += r - return xx, uu, ll, gg, kinetic_change + kinetic_change = jax.numpy.sum(total_r) + return IntegratorState(xx, uu, ll, gg), kinetic_change return step @@ -289,12 +290,14 @@ def step(state: IntegratorState, step_size): def update_position_mclmc(grad_logp): """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" - def update(step_size, x, u): + def update(step_size, x, u, l, g): + u *= 0.5 xx = x + step_size * u ll, gg = grad_logp(xx) - return xx, u, ll, gg + u *= 2. + return xx, u, ll, gg, 0 - return update + return lambda O : functools.partial(update,O) def update_momentum_mclmc(step_size): @@ -303,7 +306,7 @@ def update_momentum_mclmc(step_size): There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
""" - def update(x, u, g): + def update(x, u, l, g): g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) e = g / g_norm ue = jax.numpy.dot(u, e) @@ -313,5 +316,75 @@ def update(x, u, g): zeta = jax.numpy.exp(-delta) uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) - return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r - return update \ No newline at end of file + return x, uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), l, g, delta_r + return update + +minimal_norm = lambda O1, O2: make_integrator(O1, O2, minimal_norm_sequence) + + +# def minimal_norm(O1, O2): +# lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator + +# def step(state: IntegratorState, step_size): +# """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" + +# # V T V T V +# # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) +# uu, r1 = jax.tree_util.tree_map( +# lambda x, u, g: O1(step_size * lambda_c)(x, u, g), +# state.position, +# state.momentum, +# state.logdensity_grad, +# ) +# # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) + +# xx, _, ll, gg = jax.tree_util.tree_map( +# lambda x, u: O2(step_size, x, u), state.position, uu*0.5 +# ) +# uu, r2 = jax.tree_util.tree_map( +# lambda x, u, g: O1(step_size * (1 - 2 * lambda_c))(x, u, g), xx, uu, gg +# ) +# xx, _, ll, gg = jax.tree_util.tree_map( +# lambda x, u: O2(step_size, x, u), xx, uu*0.5 +# ) +# uu, r3 = jax.tree_util.tree_map( +# lambda x, u, g: O1(step_size * lambda_c)(x, u, g), xx, uu, gg +# ) + +# # kinetic energy change +# kinetic_change = (r1 + r2 + r3) * (uu.shape[0] - 1) + +# return xx, uu, ll, gg, kinetic_change + +# return step + + +# def update_position_mclmc(grad_logp): +# """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" + +# def update(step_size, x, u): +# xx = x + step_size * u +# ll, gg = grad_logp(xx) +# return xx, u, ll, gg + +# return update + + +# def update_momentum_mclmc(step_size): +# """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) +# similar to the implementation: https://github.com/gregversteeg/esh_dynamics +# There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
+# """ + +# def update(x, u, g): +# g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) +# e = g / g_norm +# ue = jax.numpy.dot(u, e) +# # jax.debug.print("🤯 {x} inside momentum update 🤯", x=(ue)) +# dim = u.shape[0] +# delta = step_size * g_norm / (dim - 1) +# zeta = jax.numpy.exp(-delta) +# uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u +# delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) +# return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r +# return update \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 0a92f0577..49d428949 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -85,9 +85,8 @@ def build_kernel(grad_logp, integrator, transform): def kernel( rng_key: PRNGKey, state: MCLMCState, L: float, step_size: float ) -> tuple[MCLMCState, MCLMCInfo]: - xx, uu, ll, gg, kinetic_change = step(state, step_size) - # jax.debug.print("🤯 {x} new 🤯", x=(kinetic_change, ll, state.logdensity)) - + (xx, uu, ll, gg), kinetic_change = step(state, step_size) + dim = xx.shape[0] # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) @@ -97,7 +96,7 @@ def kernel( transformed_x=transform(xx), logdensity=ll, dE=kinetic_change - ll + state.logdensity, - kinetic_change=kinetic_change, + kinetic_change=kinetic_change*(uu.shape[0] - 1), ) return kernel From 8f9214f0da6e08f6a6d124f8dd22e434e12b818e Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 Nov 2023 07:17:54 -0500 Subject: [PATCH 33/78] make uniform implementation of integrators --- blackjax/mcmc/integrators.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index c06dcf94a..0f4d10eb1 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -250,22 +250,11 @@ def one_step(state: IntegratorState, step_size: float) -> IntegratorState: return one_step -# def palindromic_sequence(s): - -# return lambda O1, O2 : list(itertools.zip_longest(itertools.cycle([O1, O2]), s)) - -# print(palindromic_sequence([1,2])('f','g')) - -# minimal_norm_sequence = palindromic_sequence([0.1931833275037836, 1., 1.- 2.*0.1931833275037836]) - -minimal_norm_sequence = lambda O1, O2 : [ - (O1, 0.1931833275037836), - (O2, 1.), - (O1, 1.- 2.*0.1931833275037836), - (O2, 1.), - (O1, 0.1931833275037836), - ] - +def palindromic_sequence(s): + # symetrize + s = s[:-1] + s[::-1] + # zip with alternating operators + return lambda O1, O2 : list(zip(itertools.cycle([O1, O2]), s)) def make_integrator(O1, O2, order): @@ -286,7 +275,6 @@ def step(state: IntegratorState, step_size): return step - def update_position_mclmc(grad_logp): """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" @@ -321,6 +309,10 @@ def update(x, u, l, g): minimal_norm = lambda O1, O2: make_integrator(O1, O2, minimal_norm_sequence) +minimal_norm_sequence = palindromic_sequence([0.1931833275037836, 1., 1.- 2.*0.1931833275037836]) +leapfrog_sequence = palindromic_sequence([0.5, 1.]) +yoshida_sequence = palindromic_sequence([0.11888010966548, 0.29619504261126, 0.5 - 0.11888010966548, 1 - 2 * 0.29619504261126]) + # def minimal_norm(O1, O2): # lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator From b2e3b8e42786e0772127f5d381c9a99db027e9f6 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 Nov 2023 09:26:08 -0500 Subject: [PATCH 34/78] fix minimal norm 
integrator --- blackjax/mcmc/integrators.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 0f4d10eb1..1ead49745 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -279,10 +279,8 @@ def update_position_mclmc(grad_logp): """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" def update(step_size, x, u, l, g): - u *= 0.5 xx = x + step_size * u ll, gg = grad_logp(xx) - u *= 2. return xx, u, ll, gg, 0 return lambda O : functools.partial(update,O) @@ -309,7 +307,7 @@ def update(x, u, l, g): minimal_norm = lambda O1, O2: make_integrator(O1, O2, minimal_norm_sequence) -minimal_norm_sequence = palindromic_sequence([0.1931833275037836, 1., 1.- 2.*0.1931833275037836]) +minimal_norm_sequence = palindromic_sequence([0.1931833275037836, 0.5, 1.- 2.*0.1931833275037836]) leapfrog_sequence = palindromic_sequence([0.5, 1.]) yoshida_sequence = palindromic_sequence([0.11888010966548, 0.29619504261126, 0.5 - 0.11888010966548, 1 - 2 * 0.29619504261126]) From 2fb229388939397bbfdcd295741bd37b1a754a16 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 Nov 2023 16:09:17 -0500 Subject: [PATCH 35/78] add warning to tune3 --- blackjax/adaptation/step_size.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index c67c500f3..85d83b57d 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -13,6 +13,7 @@ # limitations under the License. """Step size adaptation""" from typing import Callable, NamedTuple +import warnings import jax import jax.numpy as jnp @@ -500,7 +501,11 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): print("neff", neff, info.transformed_x.shape[0]) ESS_alt = 1.0 / jnp.average(1 / neff) print(ess_corr(info.transformed_x), ESS_alt, "\n\nESSse\n\n") - Lnew = Lfactor * eps / ess_corr(info.transformed_x) + ESS = ess_corr(info.transformed_x) + if ESS * num_steps <= 10: + warnings.warn("tune3 cannot be expected to work with 10 or fewer effective samples") + + Lnew = Lfactor * eps / ESS return Lnew, state From 59e442432629e93c4d95964df83e8637f7792a40 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sun, 19 Nov 2023 18:32:27 +0100 Subject: [PATCH 36/78] Refactor integrators.py to make it more general. Also add momentum update based on Esh dynamics Co-authored-by: Reuben Cohn-Gordon --- blackjax/mcmc/integrators.py | 396 ++++++++++++++++++++++----------- tests/mcmc/test_integrators.py | 19 +- 2 files changed, 280 insertions(+), 135 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 09946e9a3..2c224c9e8 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -15,6 +15,8 @@ from typing import Callable, NamedTuple import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.mcmc.metrics import EuclideanKineticEnergy from blackjax.types import ArrayTree @@ -38,11 +40,154 @@ class IntegratorState(NamedTuple): Integrator = Callable[[IntegratorState, float], IntegratorState] +def generalized_symplectic_integrator( + momentum_update_fn: Callable, + position_update_fn: Callable, + coefficients: list[float], + format_output_fn: Callable = lambda x: x, +): + """Generalized symplectic integrator. + + The generalized symplectic integrator performs numerical integration + of a Hamiltonian system by alernating between momentum and position updates. 
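# Editor's sketch (not part of the patch): in the scheme below, even-indexed
# coefficients weight momentum updates and odd-indexed coefficients weight
# position updates, and the list must read the same forwards and backwards.
# Velocity verlet is the shortest such scheme; the values here are illustrative.
coefficients = [0.5, 1.0, 0.5]  # half kick, full drift, half kick
assert coefficients == coefficients[::-1]  # palindromic
assert len(coefficients) % 2 == 1  # starts and ends with a momentum update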
+ The update scheme is decided by the coefficients and palindromic, i.e. + the coefficients of the update scheme should be symmetric with respect to the + middle of the scheme. + [TODO]: expand this with information in https://github.com/blackjax-devs/blackjax/issues/587 + + Parameters + ---------- + momentum_update_fn + Function that updates the momentum. + position_update_fn + Function that updates the position. + coefficients + Coefficients of the integrator. + format_output_fn + Function that formats the output of the integrator. + + Returns + ------- + integrator + Integrator function. + """ + + def one_step(state: IntegratorState, step_size: float): + position, momentum, _, logdensity_grad = state + # auxiliary infomation generated during integration for diagnostics. It is updated + # by the momentum_update_fn and position_update_fn at each call + momentum_update_info = None + position_update_info = None + for i, coef in enumerate(coefficients[:-1]): + if i % 2 == 0: + momentum, kinetic_grad, momentum_update_info = momentum_update_fn( + momentum, + logdensity_grad, + step_size, + coef, + momentum_update_info, + is_last_call=False, + ) + else: + ( + position, + logdensity, + logdensity_grad, + position_update_info, + ) = position_update_fn( + position, + kinetic_grad, + step_size, + coef, + position_update_info, + ) + # Separate the last steps to short circuit the computation of the kinetic_grad + momentum, kinetic_grad, momentum_update_info = momentum_update_fn( + momentum, + logdensity_grad, + step_size, + coefficients[-1], + momentum_update_info, + is_last_call=True, + ) + return format_output_fn( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, + ) + + return one_step + + def new_integrator_state(logdensity_fn, position, momentum): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) return IntegratorState(position, momentum, logdensity, logdensity_grad) +def euclidean_position_update_fn(logdensity_fn: Callable): + logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) + + def update( + position: ArrayTree, + kinetic_grad: ArrayTree, + step_size: float, + coef: float, + auxiliary_info=None, + ): + del auxiliary_info + new_position = jax.tree_util.tree_map( + lambda x, grad: x + step_size * coef * grad, + position, + kinetic_grad, + ) + logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) + return new_position, logdensity, logdensity_grad, None + + return update + + +def euclidean_momentum_update_fn(kinetic_energy_fn: EuclideanKineticEnergy): + kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) + + def update( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + auxiliary_info=None, + is_last_call=False, + ): + del auxiliary_info + new_momentum = jax.tree_util.tree_map( + lambda x, grad: x + step_size * coef * grad, + momentum, + logdensity_grad, + ) + if is_last_call: + return new_momentum, None, None + kinetic_grad = kinetic_energy_grad_fn(new_momentum) + return new_momentum, kinetic_grad, None + + return update + + +def format_euclidean_state_output( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, +): + del kinetic_grad, position_update_info, momentum_update_info + return IntegratorState(position, momentum, logdensity, logdensity_grad) + + def velocity_verlet( logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy, @@ -68,43 +213,21 @@ def 
velocity_verlet( a1 = 0 b1 = 0.5 a2 = 1 - 2 * a1 - - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a2 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - return IntegratorState(position, momentum, logdensity, logdensity_grad) - + cofficients = [b1, a2, b1] + position_update_fn = euclidean_position_update_fn(logdensity_fn) + momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) + one_step = generalized_symplectic_integrator( + momentum_update_fn, + position_update_fn, + cofficients, + format_output_fn=format_euclidean_state_output, + ) return one_step def mclachlan( logdensity_fn: Callable, - kinetic_energy_fn: Callable, + kinetic_energy_fn: EuclideanKineticEnergy, ) -> Integrator: """Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. @@ -115,61 +238,25 @@ def mclachlan( and derives different values. """ - b1 = 0.1932 + b1 = 0.1931833275037836 a1 = 0.5 b2 = 1 - 2 * b1 - - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - return IntegratorState(position, momentum, logdensity, logdensity_grad) + cofficients = [b1, a1, b2, a1, b1] + position_update_fn = euclidean_position_update_fn(logdensity_fn) + momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) + one_step = generalized_symplectic_integrator( + momentum_update_fn, + position_update_fn, + cofficients, + format_output_fn=format_euclidean_state_output, + ) return one_step def yoshida( logdensity_fn: Callable, - kinetic_energy_fn: Callable, + kinetic_energy_fn: EuclideanKineticEnergy, ) -> Integrator: """Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` @@ -184,65 +271,108 
@@ def yoshida( a1 = 0.29619504261126 b2 = 0.5 - b1 a2 = 1 - 2 * a1 + cofficients = [b1, a1, b2, a2, b2, a1, b1] + position_update_fn = euclidean_position_update_fn(logdensity_fn) + momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) + one_step = generalized_symplectic_integrator( + momentum_update_fn, + position_update_fn, + cofficients, + format_output_fn=format_euclidean_state_output, + ) - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) + return one_step - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a2 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) +# Intergrators with non Euclidean updates +def esh_dynamics_momentum_update_one_step( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + previous_kinetic_energy_change=None, + is_last_call=False, +): + """Momentum update based on Esh dynamics. + + [TODO]: update this docstring with proper references and citations. + The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) + similar to the implementation: https://github.com/gregversteeg/esh_dynamics + There are no exponentials e^delta, which prevents overflows when the gradient norm is large. 
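# Editor's sketch (not part of the patch): a quick check that the update below
# keeps the momentum on the unit sphere, which is what lets MCLMC treat it as a
# direction, while never forming exp(+delta). It assumes
# esh_dynamics_momentum_update_one_step as defined in this hunk is in scope;
# the toy momentum and gradient are illustrative only.
import jax.numpy as jnp

u = jnp.array([0.6, 0.8])      # unit-norm momentum
grad = jnp.array([1.0, -2.0])  # an arbitrary log-density gradient
new_u, _, kinetic_change = esh_dynamics_momentum_update_one_step(
    u, grad, step_size=0.1, coef=1.0
)
assert jnp.allclose(jnp.sum(jnp.square(new_u)), 1.0)  # still unit norm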
+ """ - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) + flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_momentum, _ = ravel_pytree(momentum) + dims = flatten_momentum.shape[0] + gradient_norm = jnp.sqrt(jnp.sum(jnp.square(flatten_grads))) + normalized_gradient = -flatten_grads / gradient_norm + momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) + delta = step_size * coef * gradient_norm / (dims - 1) + zeta = jnp.exp(-delta) + new_momentum = ( + normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + + 2 * zeta * flatten_momentum + ) + new_momentum_norm = new_momentum / jnp.sqrt(jnp.sum(jnp.square(new_momentum))) + kinetic_energy_change = ( + delta + - jnp.log(2) + + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) + ) + next_momentum = unravel_fn(new_momentum_norm) + if previous_kinetic_energy_change is not None: + kinetic_energy_change += previous_kinetic_energy_change + if is_last_call: + kinetic_energy_change *= dims - 1 + return next_momentum, next_momentum, kinetic_energy_change + + +def format_noneuclidean_state_output( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, +): + del kinetic_grad, position_update_info + return ( + IntegratorState(position, momentum, logdensity, logdensity_grad), + momentum_update_info, + ) + + +def non_euclidean_leapfrog(logdensity_fn: Callable, *args, **kwargs) -> Callable: + """Leapfrog integrator with non Euclidean updates. + + Similar update scheme as velocity_verlet, but with non Euclidean updates of the momentum. + """ + cofficients = [0.5, 1.0, 0.5] + position_update_fn = euclidean_position_update_fn(logdensity_fn) + one_step = generalized_symplectic_integrator( + esh_dynamics_momentum_update_one_step, + position_update_fn, + cofficients, + format_output_fn=format_noneuclidean_state_output, + ) + return one_step - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - return IntegratorState(position, momentum, logdensity, logdensity_grad) +def minimal_norm(logdensity_fn: Callable, *args, **kwargs) -> Callable: + """minimal_norm integrator with non Euclidean updates. + Similar update scheme as mclachlan, but with non Euclidean updates of the momentum. 
+ """ + b1 = 0.1931833275037836 + a1 = 0.5 + b2 = 1 - 2 * b1 + cofficients = [b1, a1, b2, a1, b1] + position_update_fn = euclidean_position_update_fn(logdensity_fn) + one_step = generalized_symplectic_integrator( + esh_dynamics_momentum_update_one_step, + position_update_fn, + cofficients, + format_output_fn=format_noneuclidean_state_output, + ) return one_step diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 68f1dbd88..56ab54458 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -51,6 +51,11 @@ def kinetic_energy(p): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "non_euclidean_leapfrog": { + "algorithm": integrators.non_euclidean_leapfrog, + "precision": 1e-4, + }, + "minimal_norm": {"algorithm": integrators.minimal_norm, "precision": 1e-5}, } @@ -101,7 +106,13 @@ class IntegratorTest(chex.TestCase): @parameterized.parameters( itertools.product( ["free_fall", "harmonic_oscillator", "planetary_motion"], - ["velocity_verlet", "mclachlan", "yoshida"], + [ + "velocity_verlet", + "mclachlan", + "yoshida", + "non_euclidean_leapfrog", + "minimal_norm", + ], ) ) def test_integrator(self, example_name, integrator_name): @@ -120,10 +131,14 @@ def test_integrator(self, example_name, integrator_name): initial_state = integrators.IntegratorState( q, p, neg_potential(q), jax.grad(neg_potential)(q) ) + if integrator_name in ["non_euclidean_leapfrog", "minimal_norm"]: + one_step = lambda _, state: step(state, step_size)[0] + else: + one_step = lambda _, state: step(state, step_size) final_state = jax.lax.fori_loop( 0, example["num_steps"], - lambda _, state: step(state, step_size), + one_step, initial_state, ) From 6684413ccb30132394d61fb257621821bbad1428 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 Nov 2023 16:40:10 -0500 Subject: [PATCH 37/78] temp: explore --- explore.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/explore.py b/explore.py index 14309c886..c19beb638 100644 --- a/explore.py +++ b/explore.py @@ -1,3 +1,4 @@ +import math from typing import NamedTuple from chex import Array import jax @@ -5,15 +6,13 @@ from scipy.fftpack import next_fast_len # type: ignore import blackjax -from blackjax.adaptation.step_size import tune +from blackjax.adaptation.step_size import MCLMCAdaptationState, tune from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import minimal_norm from blackjax.mcmc.mclmc import MCLMCState, build_kernel # from blackjax.diagnostics import effective_sample_size from blackjax.types import PRNGKey - - def logdensity_fn(x): return -0.5 * jnp.sum(jnp.square(x)) @@ -21,11 +20,8 @@ def logdensity_fn(x): def run_sampling_algorithm( sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key ): - # keys = jax.random.split(rng_key, num_steps) - keys = jnp.array([jax.random.PRNGKey(0)]*num_steps) + keys = jax.random.split(rng_key, num_steps) state = sampling_algorithm.init(initial_val) - print("\n\n", state.position, "\n\n") - print("\n\n", state.momentum, "\n\n") _, info = jax.lax.scan( lambda s, k: (sampling_algorithm.step(k, s)), state, keys ) @@ -35,35 +31,32 @@ def run_sampling_algorithm( key = jax.random.PRNGKey(0) main_key, tune_key = jax.random.split(key) +num_steps = 10000 +dim = 2 -gr = jax.value_and_grad(logdensity_fn) -kernel = 
build_kernel(grad_logp=gr, - integrator=minimal_norm, transform=lambda x:x) +params, state = tune( + params = MCLMCAdaptationState(L = math.sqrt(dim), step_size = math.sqrt(dim) * 0.4), + kernel=build_kernel(grad_logp=jax.value_and_grad(logdensity_fn), + integrator=minimal_norm, transform=lambda x:x), num_steps=num_steps, rng_key=tune_key) - -L, eps, state = (tune(kernel, num_steps=100, rng_key=tune_key)) -print("L, eps post tuning", L, eps) -raise Exception mclmc = blackjax.mcmc.mclmc.mclmc( logdensity_fn=logdensity_fn, transform=lambda x: x, - # L=0.56568545, step_size=1.4142135, inverse_mass_matrix=jnp.array([1.0, 1.0] - step_size=0.56568545, L=1.4142135, + L=params.L, step_size=params.step_size ) - out = run_sampling_algorithm( sampling_algorithm=mclmc, - num_steps=100, - initial_val=jnp.array([0.1, 0.1]), + num_steps=num_steps, + initial_val=state.position, rng_key=main_key, ) print(jnp.mean(out.transformed_x, axis=0)) -# print(logdensity_fn(jnp.array([0.1, 0.1]))) -# print(out) +print(f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}") +assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 -assert jnp.array_equal(jnp.mean(out.transformed_x, axis=0), [-1.2130139, 1.5367734]) +assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 2.8414153e-05])) From 4284092901346e7aef0df8fe8dbe87aa758f93bb Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Nov 2023 07:01:40 +0100 Subject: [PATCH 38/78] Refactor to use integrator generation functions --- blackjax/mcmc/integrators.py | 215 +++++++++++++-------------------- tests/mcmc/test_integrators.py | 11 +- 2 files changed, 94 insertions(+), 132 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 2c224c9e8..de07b95f6 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -74,8 +74,8 @@ def generalized_symplectic_integrator( def one_step(state: IntegratorState, step_size: float): position, momentum, _, logdensity_grad = state - # auxiliary infomation generated during integration for diagnostics. It is updated - # by the momentum_update_fn and position_update_fn at each call + # auxiliary infomation generated during integration for diagnostics. It is + # updated by the momentum_update_fn and position_update_fn at each call. momentum_update_info = None position_update_info = None for i, coef in enumerate(coefficients[:-1]): @@ -101,7 +101,7 @@ def one_step(state: IntegratorState, step_size: float): coef, position_update_info, ) - # Separate the last steps to short circuit the computation of the kinetic_grad + # Separate the last steps to short circuit the computation of the kinetic_grad. momentum, kinetic_grad, momentum_update_info = momentum_update_fn( momentum, logdensity_grad, @@ -188,100 +188,74 @@ def format_euclidean_state_output( return IntegratorState(position, momentum, logdensity, logdensity_grad) -def velocity_verlet( - logdensity_fn: Callable, - kinetic_energy_fn: EuclideanKineticEnergy, -) -> Integrator: - """The velocity Verlet (or Verlet-Störmer) integrator. - - The velocity Verlet is a two-stage palindromic integrator :cite:p:`bou2018geometric` of the form - (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of - the step size that range between 0 and 2 (when the mass matrix is the - identity). 
- - While the position (a1 = 0.5) and velocity Verlet are the most commonly used - in samplers, it is known in the numerical computation literature that the value - $a1 \approx 0.1932$ leads to a lower integration error :cite:p:`mclachlan1995numerical,schlick2010molecular`. The authors of :cite:p:`bou2018geometric` - show that the value $a1 \approx 0.21132$ leads to an even higher step acceptance - rate, up to 3 times higher than with the standard position verlet (p.22, Fig.4). - - By choosing the velocity verlet we avoid two computations of the gradient - of the kinetic energy. We are trading accuracy in exchange, and it is not - clear whether this is the right tradeoff. - - """ - a1 = 0 - b1 = 0.5 - a2 = 1 - 2 * a1 - cofficients = [b1, a2, b1] - position_update_fn = euclidean_position_update_fn(logdensity_fn) - momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) - one_step = generalized_symplectic_integrator( - momentum_update_fn, - position_update_fn, - cofficients, - format_output_fn=format_euclidean_state_output, - ) - return one_step - - -def mclachlan( - logdensity_fn: Callable, - kinetic_energy_fn: EuclideanKineticEnergy, -) -> Integrator: - """Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. - - The integrator is of the form (b1, a1, b2, a1, b1). The choice of the parameters - determine both the bound on the integration error and the stability of the - method with respect to the value of `step_size`. The values used here are - the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` is more focused on stability - and derives different values. - - """ - b1 = 0.1931833275037836 - a1 = 0.5 - b2 = 1 - 2 * b1 - cofficients = [b1, a1, b2, a1, b1] - position_update_fn = euclidean_position_update_fn(logdensity_fn) - momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) - one_step = generalized_symplectic_integrator( - momentum_update_fn, - position_update_fn, - cofficients, - format_output_fn=format_euclidean_state_output, - ) - - return one_step - - -def yoshida( - logdensity_fn: Callable, - kinetic_energy_fn: EuclideanKineticEnergy, -) -> Integrator: - """Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` - - The integrator is of the form (b1, a1, b2, a2, b2, a1, b1). The choice of - the parameters determine both the bound on the integration error and the - stability of the method with respect to the value of `step_size`. The - values used here are the ones derived in :cite:p:`mclachlan1995numerical` which guarantees a stability - interval length approximately equal to 4.67. 
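As an aside, the stability figures quoted in these docstrings (step sizes in (0, 2) for velocity Verlet, an interval of length roughly 4.67 for the three-stage scheme) can be probed on a one-dimensional standard Gaussian, for which every update is linear and the scheme is stable exactly when the one-step transfer matrix has |trace| <= 2. A sketch under that assumption, with the coefficient values copied from this file:

import numpy as np

def transfer_matrix_trace(step_size, coefficients):
    # momentum updates sit at even indices, position updates at odd indices,
    # matching the alternation performed by the generalized integrator
    m = np.eye(2)
    for i, c in enumerate(coefficients):
        if i % 2 == 0:
            stage = np.array([[1.0, 0.0], [-c * step_size, 1.0]])  # p <- p - c*eps*q
        else:
            stage = np.array([[1.0, c * step_size], [0.0, 1.0]])   # q <- q + c*eps*p
        m = stage @ m
    return abs(np.trace(m))

velocity_verlet_coeffs = [0.5, 1.0, 0.5]
b1, a1 = 0.11888010966548, 0.29619504261126
yoshida_coeffs = [b1, a1, 0.5 - b1, 1.0 - 2.0 * a1, 0.5 - b1, a1, b1]

print(transfer_matrix_trace(1.9, velocity_verlet_coeffs) <= 2.0)  # True: inside (0, 2)
print(transfer_matrix_trace(2.1, velocity_verlet_coeffs) <= 2.0)  # False: outside
print(transfer_matrix_trace(4.6, yoshida_coeffs) <= 2.0)  # expected True if the interval is ~(0, 4.67)
print(transfer_matrix_trace(4.8, yoshida_coeffs) <= 2.0)  # expected False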
- - """ - b1 = 0.11888010966548 - a1 = 0.29619504261126 - b2 = 0.5 - b1 - a2 = 1 - 2 * a1 - cofficients = [b1, a1, b2, a2, b2, a1, b1] - position_update_fn = euclidean_position_update_fn(logdensity_fn) - momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) - one_step = generalized_symplectic_integrator( - momentum_update_fn, - position_update_fn, - cofficients, - format_output_fn=format_euclidean_state_output, - ) - - return one_step +def generate_euclidean_integrator(cofficients): + def euclidean_integrator( + logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy + ) -> Integrator: + position_update_fn = euclidean_position_update_fn(logdensity_fn) + momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) + one_step = generalized_symplectic_integrator( + momentum_update_fn, + position_update_fn, + cofficients, + format_output_fn=format_euclidean_state_output, + ) + return one_step + + return euclidean_integrator + + +""" +The velocity Verlet (or Verlet-Störmer) integrator. + +The velocity Verlet is a two-stage palindromic integrator :cite:p:`bou2018geometric` +of the form (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of +the step size that range between 0 and 2 (when the mass matrix is the identity). + +While the position (a1 = 0.5) and velocity Verlet are the most commonly used +in samplers, it is known in the numerical computation literature that the value +$a1 \approx 0.1932$ leads to a lower integration error :cite:p:`mclachlan1995numerical,schlick2010molecular`. +The authors of :cite:p:`bou2018geometric` show that the value $a1 \approx 0.21132$ +leads to an even higher step acceptance rate, up to 3 times higher +than with the standard position verlet (p.22, Fig.4). + +By choosing the velocity verlet we avoid two computations of the gradient +of the kinetic energy. We are trading accuracy in exchange, and it is not +clear whether this is the right tradeoff. +""" +velocity_verlet_cofficients = [0.5, 1.0, 0.5] +velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients) + +""" +Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. + +The integrator is of the form (b1, a1, b2, a1, b1). The choice of the parameters +determine both the bound on the integration error and the stability of the +method with respect to the value of `step_size`. The values used here are +the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` +is more focused on stability and derives different values. +""" +b1 = 0.1931833275037836 +a1 = 0.5 +b2 = 1 - 2 * b1 +mclachlan_cofficients = [b1, a1, b2, a1, b1] +mclachlan = generate_euclidean_integrator(mclachlan_cofficients) + +""" +Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` + +The integrator is of the form (b1, a1, b2, a2, b2, a1, b1). The choice of +the parameters determine both the bound on the integration error and the +stability of the method with respect to the value of `step_size`. The +values used here are the ones derived in :cite:p:`mclachlan1995numerical` which +guarantees a stability interval length approximately equal to 4.67. 
+""" +b1 = 0.11888010966548 +a1 = 0.29619504261126 +b2 = 0.5 - b1 +a2 = 1 - 2 * a1 +yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1] +yoshida = generate_euclidean_integrator(yoshida_cofficients) # Intergrators with non Euclidean updates @@ -298,7 +272,8 @@ def esh_dynamics_momentum_update_one_step( [TODO]: update this docstring with proper references and citations. The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) similar to the implementation: https://github.com/gregversteeg/esh_dynamics - There are no exponentials e^delta, which prevents overflows when the gradient norm is large. + There are no exponentials e^delta, which prevents overflows when the gradient norm + is large. """ flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) @@ -343,36 +318,20 @@ def format_noneuclidean_state_output( ) -def non_euclidean_leapfrog(logdensity_fn: Callable, *args, **kwargs) -> Callable: - """Leapfrog integrator with non Euclidean updates. - - Similar update scheme as velocity_verlet, but with non Euclidean updates of the momentum. - """ - cofficients = [0.5, 1.0, 0.5] - position_update_fn = euclidean_position_update_fn(logdensity_fn) - one_step = generalized_symplectic_integrator( - esh_dynamics_momentum_update_one_step, - position_update_fn, - cofficients, - format_output_fn=format_noneuclidean_state_output, - ) - return one_step +def generate_noneuclidean_integrator(cofficients): + def noneuclidean_integrator(logdensity_fn: Callable, *args, **kwargs) -> Callable: + position_update_fn = euclidean_position_update_fn(logdensity_fn) + one_step = generalized_symplectic_integrator( + esh_dynamics_momentum_update_one_step, + position_update_fn, + cofficients, + format_output_fn=format_noneuclidean_state_output, + ) + return one_step + return noneuclidean_integrator -def minimal_norm(logdensity_fn: Callable, *args, **kwargs) -> Callable: - """minimal_norm integrator with non Euclidean updates. - Similar update scheme as mclachlan, but with non Euclidean updates of the momentum. 
- """ - b1 = 0.1931833275037836 - a1 = 0.5 - b2 = 1 - 2 * b1 - cofficients = [b1, a1, b2, a1, b1] - position_update_fn = euclidean_position_update_fn(logdensity_fn) - one_step = generalized_symplectic_integrator( - esh_dynamics_momentum_update_one_step, - position_update_fn, - cofficients, - format_output_fn=format_noneuclidean_state_output, - ) - return one_step +noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients) +noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) +noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 56ab54458..a93d5ed2a 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -52,10 +52,13 @@ def kinetic_energy(p): "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, "non_euclidean_leapfrog": { - "algorithm": integrators.non_euclidean_leapfrog, + "algorithm": integrators.noneuclidean_leapfrog, "precision": 1e-4, }, - "minimal_norm": {"algorithm": integrators.minimal_norm, "precision": 1e-5}, + "non_euclidean_mclachlan": { + "algorithm": integrators.noneuclidean_mclachlan, + "precision": 1e-5, + }, } @@ -110,8 +113,8 @@ class IntegratorTest(chex.TestCase): "velocity_verlet", "mclachlan", "yoshida", - "non_euclidean_leapfrog", - "minimal_norm", + # "noneuclidean_leapfrog", + # "noneuclidean_mclachlan", ], ) ) From 4a514ddd6217eb1e2c22bc1fb34db29e9dbc57b7 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Nov 2023 17:41:28 +0100 Subject: [PATCH 39/78] Additional refactoring Also add test for esh momentum update. Co-authored-by: Reuben Cohn-Gordon --- blackjax/mcmc/integrators.py | 83 +++++++++++++++--------- docs/refs.bib | 18 ++++++ tests/mcmc/test_integrators.py | 112 ++++++++++++++++++++++++++------- 3 files changed, 163 insertions(+), 50 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index de07b95f6..9753e9e53 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -21,7 +21,14 @@ from blackjax.mcmc.metrics import EuclideanKineticEnergy from blackjax.types import ArrayTree -__all__ = ["mclachlan", "velocity_verlet", "yoshida"] +__all__ = [ + "mclachlan", + "velocity_verlet", + "yoshida", + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", +] class IntegratorState(NamedTuple): @@ -40,27 +47,37 @@ class IntegratorState(NamedTuple): Integrator = Callable[[IntegratorState, float], IntegratorState] -def generalized_symplectic_integrator( - momentum_update_fn: Callable, - position_update_fn: Callable, +def generalized_two_stage_integrator( + operator1: Callable, + operator2: Callable, coefficients: list[float], format_output_fn: Callable = lambda x: x, ): - """Generalized symplectic integrator. + """Generalized numerical integrator for solving ODEs. - The generalized symplectic integrator performs numerical integration - of a Hamiltonian system by alernating between momentum and position updates. - The update scheme is decided by the coefficients and palindromic, i.e. - the coefficients of the update scheme should be symmetric with respect to the + The generalized integrator performs numerical integration of a ODE system by + alernating between stage 1 and stage 2 updates. + The update scheme is decided by the coefficients, The scheme should be palindromic, + i.e. 
the coefficients of the update scheme should be symmetric with respect to the middle of the scheme. - [TODO]: expand this with information in https://github.com/blackjax-devs/blackjax/issues/587 + + For instance, for *any* differential equation of the form: + + .. math:: \\frac{d}{dt}f = (O_1+O_2)f + + The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` + by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`. + + In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and + :math:`e^{\\epsilon O_1}` are simple, but for other differential equations, + they may be more complex. Parameters ---------- - momentum_update_fn - Function that updates the momentum. - position_update_fn - Function that updates the position. + operator1 + Stage 1 operator, a function that updates the momentum. + operator2 + Stage 2 operator, a function that updates the position. coefficients Coefficients of the integrator. format_output_fn @@ -75,12 +92,12 @@ def generalized_symplectic_integrator( def one_step(state: IntegratorState, step_size: float): position, momentum, _, logdensity_grad = state # auxiliary infomation generated during integration for diagnostics. It is - # updated by the momentum_update_fn and position_update_fn at each call. + # updated by the operator1 and operator2 at each call. momentum_update_info = None position_update_info = None for i, coef in enumerate(coefficients[:-1]): if i % 2 == 0: - momentum, kinetic_grad, momentum_update_info = momentum_update_fn( + momentum, kinetic_grad, momentum_update_info = operator1( momentum, logdensity_grad, step_size, @@ -94,7 +111,7 @@ def one_step(state: IntegratorState, step_size: float): logdensity, logdensity_grad, position_update_info, - ) = position_update_fn( + ) = operator2( position, kinetic_grad, step_size, @@ -102,7 +119,7 @@ def one_step(state: IntegratorState, step_size: float): position_update_info, ) # Separate the last steps to short circuit the computation of the kinetic_grad. - momentum, kinetic_grad, momentum_update_info = momentum_update_fn( + momentum, kinetic_grad, momentum_update_info = operator1( momentum, logdensity_grad, step_size, @@ -189,12 +206,18 @@ def format_euclidean_state_output( def generate_euclidean_integrator(cofficients): + """Generate symplectic integrator for solving a Hamiltonian system. + + The resulting integrator is volume-preserve and preserves the symplectic structure + of phase space. + """ + def euclidean_integrator( logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy ) -> Integrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) - one_step = generalized_symplectic_integrator( + one_step = generalized_two_stage_integrator( momentum_update_fn, position_update_fn, cofficients, @@ -234,6 +257,8 @@ def euclidean_integrator( method with respect to the value of `step_size`. The values used here are the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` is more focused on stability and derives different values. + +Also known as the minimal norm integrator. 
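For completeness, a short usage sketch of the generated Euclidean integrators, mirroring the energy check performed in tests/mcmc/test_integrators.py; the toy Gaussian target and the step size of 0.1 are illustrative choices, not values from the patch:

import jax
import jax.numpy as jnp
from blackjax.mcmc import integrators

logdensity_fn = lambda q: -0.5 * jnp.sum(q**2)
kinetic_energy_fn = lambda p: 0.5 * jnp.sum(p**2)
one_step = integrators.mclachlan(logdensity_fn, kinetic_energy_fn)

q = jnp.array([1.0, 0.0])
p = jnp.array([0.0, 1.0])
state = integrators.IntegratorState(q, p, logdensity_fn(q), jax.grad(logdensity_fn)(q))

energy = -state.logdensity + kinetic_energy_fn(state.momentum)
new_state = one_step(state, 0.1)
new_energy = -new_state.logdensity + kinetic_energy_fn(new_state.momentum)
# the difference should be tiny, cf. the ~1e-5 precision used for mclachlan in the tests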
""" b1 = 0.1931833275037836 a1 = 0.5 @@ -259,6 +284,11 @@ def euclidean_integrator( # Intergrators with non Euclidean updates +def normalized_flatten_array(x, tol=1e-13): + norm = jnp.sqrt(jnp.sum(jnp.square(x))) + return jnp.where(norm > tol, x / norm, x), norm + + def esh_dynamics_momentum_update_one_step( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -269,9 +299,7 @@ def esh_dynamics_momentum_update_one_step( ): """Momentum update based on Esh dynamics. - [TODO]: update this docstring with proper references and citations. - The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) - similar to the implementation: https://github.com/gregversteeg/esh_dynamics + The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` There are no exponentials e^delta, which prevents overflows when the gradient norm is large. """ @@ -279,22 +307,21 @@ def esh_dynamics_momentum_update_one_step( flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] - gradient_norm = jnp.sqrt(jnp.sum(jnp.square(flatten_grads))) - normalized_gradient = -flatten_grads / gradient_norm + normalized_gradient, gradient_norm = normalized_flatten_array(flatten_grads) momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) delta = step_size * coef * gradient_norm / (dims - 1) zeta = jnp.exp(-delta) - new_momentum = ( + new_momentum_raw = ( normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + 2 * zeta * flatten_momentum ) - new_momentum_norm = new_momentum / jnp.sqrt(jnp.sum(jnp.square(new_momentum))) + new_momentum_normalized, _ = normalized_flatten_array(new_momentum_raw) + next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta - jnp.log(2) + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) ) - next_momentum = unravel_fn(new_momentum_norm) if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change if is_last_call: @@ -321,7 +348,7 @@ def format_noneuclidean_state_output( def generate_noneuclidean_integrator(cofficients): def noneuclidean_integrator(logdensity_fn: Callable, *args, **kwargs) -> Callable: position_update_fn = euclidean_position_update_fn(logdensity_fn) - one_step = generalized_symplectic_integrator( + one_step = generalized_two_stage_integrator( esh_dynamics_momentum_update_one_step, position_update_fn, cofficients, diff --git a/docs/refs.bib b/docs/refs.bib index f5015ccb9..1b6485809 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -360,3 +360,21 @@ @inproceedings{hoffman2021adaptive year={2021}, organization={PMLR} } + +@misc{steeg2021hamiltonian, + title={Hamiltonian Dynamics with Non-Newtonian Momentum for Rapid Sampling}, + author={Greg Ver Steeg and Aram Galstyan}, + year={2021}, + eprint={2111.02434}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} + +@misc{robnik2023microcanonical, + title={Microcanonical Hamiltonian Monte Carlo}, + author={Jakob Robnik and G. 
Bruno De Luca and Eva Silverstein and Uroš Seljak}, + year={2023}, + eprint={2212.08549}, + archivePrefix={arXiv}, + primaryClass={stat.CO} +} diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index a93d5ed2a..350dba73e 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -3,9 +3,13 @@ import chex import jax import jax.numpy as jnp +import jax.scipy.stats as stats +import numpy as np from absl.testing import absltest, parameterized +from jax.flatten_util import ravel_pytree import blackjax.mcmc.integrators as integrators +from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0): @@ -47,20 +51,37 @@ def kinetic_energy(p): return neg_potential_energy, kinetic_energy -algorithms = { - "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, - "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, - "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, - "non_euclidean_leapfrog": { - "algorithm": integrators.noneuclidean_leapfrog, - "precision": 1e-4, - }, - "non_euclidean_mclachlan": { - "algorithm": integrators.noneuclidean_mclachlan, - "precision": 1e-5, - }, -} +def MultivariateNormal(inv_mass_matrix): + """Potential and kinetic energy for a multivariate normal distribution.""" + def log_density(q): + q, _ = ravel_pytree(q) + return stats.multivariate_normal.logpdf(q, jnp.zeros_like(q), inv_mass_matrix) + + def kinetic_energy(p): + p, _ = ravel_pytree(p) + return 0.5 * p.T @ inv_mass_matrix @ p + + return log_density, kinetic_energy + + +mvnormal_position_init = { + "a": 0.0, + "b": jnp.asarray([1.0, 2.0, 3.0]), + "c": jnp.ones((2, 1)), +} +_, unravel_fn = ravel_pytree(mvnormal_position_init) +key0, key1 = jax.random.split(jax.random.key(52)) +mvnormal_momentum_init = unravel_fn(jax.random.normal(key0, (6,))) +a = jax.random.normal(key1, (6, 6)) +cov = jnp.matmul(a.T, a) +# Validated numerically +mvnormal_position_end = unravel_fn( + jnp.asarray([0.38887993, 0.85231394, 2.7879136, 3.0339851, 0.5856687, 1.9291426]) +) +mvnormal_momentum_end = unravel_fn( + jnp.asarray([0.46576163, 0.23854092, 1.2518811, -0.35647452, -0.742138, 1.2552949]) +) examples = { "free_fall": { @@ -93,6 +114,22 @@ def kinetic_energy(p): "p_final": {"x": 0.0, "y": 1.0}, "inv_mass_matrix": jnp.array([1.0, 1.0]), }, + "multivariate_normal": { + "model": MultivariateNormal, + "num_steps": 16, + "step_size": 0.005, + "q_init": mvnormal_position_init, + "p_init": mvnormal_momentum_init, + "q_final": mvnormal_position_end, + "p_final": mvnormal_momentum_end, + "inv_mass_matrix": cov, + }, +} + +algorithms = { + "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, + "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, + "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, } @@ -108,17 +145,20 @@ class IntegratorTest(chex.TestCase): @chex.all_variants(with_pmap=False) @parameterized.parameters( itertools.product( - ["free_fall", "harmonic_oscillator", "planetary_motion"], + [ + "free_fall", + "harmonic_oscillator", + "planetary_motion", + "multivariate_normal", + ], [ "velocity_verlet", "mclachlan", "yoshida", - # "noneuclidean_leapfrog", - # "noneuclidean_mclachlan", ], ) ) - def test_integrator(self, example_name, integrator_name): + def test_euclidean_integrator(self, example_name, integrator_name): integrator = algorithms[integrator_name] example = examples[example_name] @@ -134,14 
+174,11 @@ def test_integrator(self, example_name, integrator_name): initial_state = integrators.IntegratorState( q, p, neg_potential(q), jax.grad(neg_potential)(q) ) - if integrator_name in ["non_euclidean_leapfrog", "minimal_norm"]: - one_step = lambda _, state: step(state, step_size)[0] - else: - one_step = lambda _, state: step(state, step_size) + final_state = jax.lax.fori_loop( 0, example["num_steps"], - one_step, + lambda _, state: step(state, step_size), initial_state, ) @@ -155,6 +192,37 @@ def test_integrator(self, example_name, integrator_name): ) self.assertAlmostEqual(energy, new_energy, delta=integrator["precision"]) + @chex.all_variants(with_pmap=False) + @parameterized.parameters([3, 5]) + def test_esh_momentum_update(self, dims): + """ + Test the numerically efficient version of the momentum update currently + implemented match the naive implementation according to the equation in + :cite:p:`robnik2023microcanonical` + """ + step_size = 1e-3 + momentum = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(dims,)) + momentum /= jnp.linalg.norm(momentum) + gradient = jax.random.uniform(key=jax.random.PRNGKey(1), shape=(dims,)) + + # Navie implementation + gradient_norm = jnp.linalg.norm(gradient) + gradient_normalized = gradient / gradient_norm + delta = step_size * gradient_norm / (dims - 1) + next_momentum = ( + momentum + + gradient_normalized + * ( + jnp.sinh(delta) + + jnp.dot(gradient_normalized, momentum * (jnp.cosh(delta) - 1)) + ) + ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta))) + + # Efficient implementation + update_stable = self.variant(esh_dynamics_momentum_update_one_step) + next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) + np.testing.assert_array_almost_equal(next_momentum, next_momentum1) + if __name__ == "__main__": absltest.main() From ef1f62dd47bef407fd36589d4ad442d1d2fb348b Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Tue, 21 Nov 2023 07:30:51 +0100 Subject: [PATCH 40/78] Minor clean up. 
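The equivalence that the new test checks numerically also follows from a short calculation: with e the normalized gradient, u the unit momentum, delta the effective step and zeta = exp(-delta), the identities 2*zeta*sinh(delta) = 1 - zeta**2 and 2*zeta*(cosh(delta) - 1) = (1 - zeta)**2 give

.. math:: 2\zeta\left[u + e\big(\sinh\delta + (e\cdot u)(\cosh\delta - 1)\big)\right] = 2\zeta\, u + e\,(1-\zeta)\big(1 + \zeta + (e\cdot u)(1-\zeta)\big)

The right-hand side is exactly new_momentum_raw in esh_dynamics_momentum_update_one_step; since the result is re-normalized, the positive prefactor 2*zeta and the denominator cosh(delta) + (e·u) sinh(delta) of the reference expression drop out, so no growing exponential is ever formed.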
--- blackjax/mcmc/integrators.py | 13 +++++++++---- tests/mcmc/test_integrators.py | 7 ++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 9753e9e53..cc45a7e27 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -45,6 +45,9 @@ class IntegratorState(NamedTuple): Integrator = Callable[[IntegratorState, float], IntegratorState] +GeneralIntegrator = Callable[ + [IntegratorState, float], tuple[IntegratorState, ArrayTree] +] def generalized_two_stage_integrator( @@ -284,7 +287,7 @@ def euclidean_integrator( # Intergrators with non Euclidean updates -def normalized_flatten_array(x, tol=1e-13): +def _normalized_flatten_array(x, tol=1e-13): norm = jnp.sqrt(jnp.sum(jnp.square(x))) return jnp.where(norm > tol, x / norm, x), norm @@ -307,7 +310,7 @@ def esh_dynamics_momentum_update_one_step( flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] - normalized_gradient, gradient_norm = normalized_flatten_array(flatten_grads) + normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) delta = step_size * coef * gradient_norm / (dims - 1) zeta = jnp.exp(-delta) @@ -315,7 +318,7 @@ def esh_dynamics_momentum_update_one_step( normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + 2 * zeta * flatten_momentum ) - new_momentum_normalized, _ = normalized_flatten_array(new_momentum_raw) + new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -346,7 +349,9 @@ def format_noneuclidean_state_output( def generate_noneuclidean_integrator(cofficients): - def noneuclidean_integrator(logdensity_fn: Callable, *args, **kwargs) -> Callable: + def noneuclidean_integrator( + logdensity_fn: Callable, *args, **kwargs + ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( esh_dynamics_momentum_update_one_step, diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 350dba73e..a41877c13 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -197,13 +197,14 @@ def test_euclidean_integrator(self, example_name, integrator_name): def test_esh_momentum_update(self, dims): """ Test the numerically efficient version of the momentum update currently - implemented match the naive implementation according to the equation in + implemented match the naive implementation according to the Equation 16 in :cite:p:`robnik2023microcanonical` """ step_size = 1e-3 - momentum = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(dims,)) + key0, key1 = jax.random.split(jax.random.key(62)) + gradient = jax.random.uniform(key0, shape=(dims,)) + momentum = jax.random.uniform(key1, shape=(dims,)) momentum /= jnp.linalg.norm(momentum) - gradient = jax.random.uniform(key=jax.random.PRNGKey(1), shape=(dims,)) # Navie implementation gradient_norm = jnp.linalg.norm(gradient) From af43521f06201d4cbfa48ec47fe7c8c3fd216068 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Tue, 21 Nov 2023 17:09:50 +0100 Subject: [PATCH 41/78] Use standard JAX ops --- blackjax/mcmc/integrators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index cc45a7e27..e871b6211 
100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -288,7 +288,7 @@ def euclidean_integrator( # Intergrators with non Euclidean updates def _normalized_flatten_array(x, tol=1e-13): - norm = jnp.sqrt(jnp.sum(jnp.square(x))) + norm = jnp.linalg.norm(x) return jnp.where(norm > tol, x / norm, x), norm From 0dd419db4635fd98d7f12f203b8c8261c216ce75 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 23 Nov 2023 13:52:51 -0500 Subject: [PATCH 42/78] new integrator --- blackjax/adaptation/step_size.py | 16 +- blackjax/diagnostics.py | 2 - blackjax/mcmc/integrators.py | 594 +++++++++++++++---------------- blackjax/mcmc/mclmc.py | 27 +- 4 files changed, 306 insertions(+), 333 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 85d83b57d..b16807985 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -21,7 +21,7 @@ from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.mclmc import MCLMCState +from blackjax.mcmc.mclmc import IntegratorState from blackjax.optimizers.dual_averaging import dual_averaging from blackjax.types import Array, ArrayLikeTree, PRNGKey @@ -281,14 +281,10 @@ def ess_corr(x): ] ) - print(input_array.shape,"input shape 2") - - num_chains = 1 # input_array.shape[0] num_samples = input_array.shape[1] mean_across_chain = input_array.mean(axis=1, keepdims=True) - print("mean 2", mean_across_chain) # Compute autocovariance estimates for every lag for the input array using FFT. centered_array = input_array - mean_across_chain m = next_fast_len(2 * num_samples) @@ -305,7 +301,6 @@ def ess_corr(x): / (num_samples - 1.0) ) weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - jax.debug.print("🤯 {x} weighted_var 2 🤯", x=weighted_var) weighted_var = jax.lax.cond( num_chains > 1, @@ -377,7 +372,6 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): ess = ess_raw / tau_hat neff = ess.squeeze() / num_samples - print("tau hat", ess, num_samples, neff) return 1.0 / jnp.average(1 / neff) @@ -405,7 +399,7 @@ def dynamics_adaptive(dynamics, state, L): ) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences state, info = dynamics( - jax.random.PRNGKey(0), MCLMCState(x, u, l, g), L=L, step_size=eps + jax.random.PRNGKey(0), IntegratorState(x, u, l, g), L=L, step_size=eps ) xx, uu, ll, gg = state @@ -484,7 +478,7 @@ def step(state, outer_weight): return ( L, eps[-1], - MCLMCState(xx, uu, ll, gg), + IntegratorState(xx, uu, ll, gg), ) # return the tuned hyperparameters and the final state @@ -498,9 +492,7 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): Lfactor = 0.4 ESS2 = effective_sample_size(info.transformed_x) neff = ESS2.squeeze() / info.transformed_x.shape[0] - print("neff", neff, info.transformed_x.shape[0]) ESS_alt = 1.0 / jnp.average(1 / neff) - print(ess_corr(info.transformed_x), ESS_alt, "\n\nESSse\n\n") ESS = ess_corr(info.transformed_x) if ESS * num_steps <= 10: warnings.warn("tune3 cannot be expected to work with 10 or fewer effective samples") @@ -511,7 +503,7 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): def tune( kernel, num_steps: int, rng_key: PRNGKey, params: MCLMCAdaptationState -) -> tuple[MCLMCAdaptationState, MCLMCState]: +) -> tuple[MCLMCAdaptationState, IntegratorState]: num_tune_step_ratio_1 = 0.1 num_tune_step_ratio_2 = 0.1 diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index 2f9bad5ba..da861d9b1 100644 --- 
a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -112,7 +112,6 @@ def effective_sample_size( """ input_shape = input_array.shape - print(input_shape, "input shape") sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis num_chains = input_shape[chain_axis] num_samples = input_shape[sample_axis] @@ -144,7 +143,6 @@ def effective_sample_size( lambda _: weighted_var, operand=None, ) - jax.debug.print("🤯 {x} weighted_var 🤯", x=weighted_var) # Geyer's initial positive sequence num_samples_even = num_samples - num_samples % 2 diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 1ead49745..d9e9bff3a 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -1,3 +1,4 @@ +# @title `integrators.py` from https://github.com/blackjax-devs/blackjax/pull/589 # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,16 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. """Symplectic, time-reversible, integrators for Hamiltonian trajectories.""" -import functools -import itertools from typing import Callable, NamedTuple import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.mcmc.metrics import EuclideanKineticEnergy from blackjax.types import ArrayTree -__all__ = ["mclachlan", "velocity_verlet", "yoshida"] +__all__ = [ + "mclachlan", + "velocity_verlet", + "yoshida", + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", +] class IntegratorState(NamedTuple): @@ -38,343 +46,325 @@ class IntegratorState(NamedTuple): Integrator = Callable[[IntegratorState, float], IntegratorState] - - -def new_integrator_state(logdensity_fn, position, momentum): - logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return IntegratorState(position, momentum, logdensity, logdensity_grad) - - -def velocity_verlet( - logdensity_fn: Callable, - kinetic_energy_fn: EuclideanKineticEnergy, -) -> Integrator: - """The velocity Verlet (or Verlet-Störmer) integrator. - - The velocity Verlet is a two-stage palindromic integrator :cite:p:`bou2018geometric` of the form - (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of - the step size that range between 0 and 2 (when the mass matrix is the - identity). - - While the position (a1 = 0.5) and velocity Verlet are the most commonly used - in samplers, it is known in the numerical computation literature that the value - $a1 \approx 0.1932$ leads to a lower integration error :cite:p:`mclachlan1995numerical,schlick2010molecular`. The authors of :cite:p:`bou2018geometric` - show that the value $a1 \approx 0.21132$ leads to an even higher step acceptance - rate, up to 3 times higher than with the standard position verlet (p.22, Fig.4). - - By choosing the velocity verlet we avoid two computations of the gradient - of the kinetic energy. We are trading accuracy in exchange, and it is not - clear whether this is the right tradeoff. - +GeneralIntegrator = Callable[ + [IntegratorState, float], tuple[IntegratorState, ArrayTree] +] + + +def generalized_two_stage_integrator( + operator1: Callable, + operator2: Callable, + coefficients: list[float], + format_output_fn: Callable = lambda x: x, +): + """Generalized numerical integrator for solving ODEs. + + The generalized integrator performs numerical integration of a ODE system by + alernating between stage 1 and stage 2 updates. 
+ The update scheme is decided by the coefficients, The scheme should be palindromic, + i.e. the coefficients of the update scheme should be symmetric with respect to the + middle of the scheme. + + For instance, for *any* differential equation of the form: + + .. math:: \\frac{d}{dt}f = (O_1+O_2)f + + The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` + by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`. + + In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and + :math:`e^{\\epsilon O_1}` are simple, but for other differential equations, + they may be more complex. + + Parameters + ---------- + operator1 + Stage 1 operator, a function that updates the momentum. + operator2 + Stage 2 operator, a function that updates the position. + coefficients + Coefficients of the integrator. + format_output_fn + Function that formats the output of the integrator. + + Returns + ------- + integrator + Integrator function. """ - a1 = 0 - b1 = 0.5 - a2 = 1 - 2 * a1 - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: + def one_step(state: IntegratorState, step_size: float): position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, + # auxiliary infomation generated during integration for diagnostics. It is + # updated by the operator1 and operator2 at each call. + momentum_update_info = None + position_update_info = None + for i, coef in enumerate(coefficients[:-1]): + if i % 2 == 0: + momentum, kinetic_grad, momentum_update_info = operator1( + momentum, + logdensity_grad, + step_size, + coef, + momentum_update_info, + is_last_call=False, + ) + else: + ( + position, + logdensity, + logdensity_grad, + position_update_info, + ) = operator2( + position, + kinetic_grad, + step_size, + coef, + position_update_info, + ) + # Separate the last steps to short circuit the computation of the kinetic_grad. + momentum, kinetic_grad, momentum_update_info = operator1( momentum, logdensity_grad, + step_size, + coefficients[-1], + momentum_update_info, + is_last_call=True, ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a2 * step_size * kinetic_grad, + return format_output_fn( position, - kinetic_grad, - ) - - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, momentum, + logdensity, logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, ) - return IntegratorState(position, momentum, logdensity, logdensity_grad) - return one_step -def mclachlan( - logdensity_fn: Callable, - kinetic_energy_fn: Callable, -) -> Integrator: - """Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. - - The integrator is of the form (b1, a1, b2, a1, b1). The choice of the parameters - determine both the bound on the integration error and the stability of the - method with respect to the value of `step_size`. The values used here are - the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` is more focused on stability - and derives different values. 
+def new_integrator_state(logdensity_fn, position, momentum): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return IntegratorState(position, momentum, logdensity, logdensity_grad) - """ - b1 = 0.1932 - a1 = 0.5 - b2 = 1 - 2 * b1 +def euclidean_position_update_fn(logdensity_fn: Callable): logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, + def update( + position: ArrayTree, + kinetic_grad: ArrayTree, + step_size: float, + coef: float, + auxiliary_info=None, + ): + del auxiliary_info + new_position = jax.tree_util.tree_map( + lambda x, grad: x + step_size * coef * grad, position, kinetic_grad, ) + logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) + return new_position, logdensity, logdensity_grad, None - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) + return update - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, +def euclidean_momentum_update_fn(kinetic_energy_fn: EuclideanKineticEnergy): + kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) + + def update( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + auxiliary_info=None, + is_last_call=False, + ): + del auxiliary_info + new_momentum = jax.tree_util.tree_map( + lambda x, grad: x + step_size * coef * grad, momentum, logdensity_grad, ) + if is_last_call: + return new_momentum, None, None + kinetic_grad = kinetic_energy_grad_fn(new_momentum) + return new_momentum, kinetic_grad, None - return IntegratorState(position, momentum, logdensity, logdensity_grad) + return update - return one_step +def format_euclidean_state_output( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, +): + del kinetic_grad, position_update_info, momentum_update_info + return IntegratorState(position, momentum, logdensity, logdensity_grad) -def yoshida( - logdensity_fn: Callable, - kinetic_energy_fn: Callable, -) -> Integrator: - """Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` - The integrator is of the form (b1, a1, b2, a2, b2, a1, b1). The choice of - the parameters determine both the bound on the integration error and the - stability of the method with respect to the value of `step_size`. The - values used here are the ones derived in :cite:p:`mclachlan1995numerical` which guarantees a stability - interval length approximately equal to 4.67. +def generate_euclidean_integrator(cofficients): + """Generate symplectic integrator for solving a Hamiltonian system. 
+ The resulting integrator is volume-preserve and preserves the symplectic structure + of phase space. """ - b1 = 0.11888010966548 - a1 = 0.29619504261126 - b2 = 0.5 - b1 - a2 = 1 - 2 * a1 - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a2 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, + def euclidean_integrator( + logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy + ) -> Integrator: + position_update_fn = euclidean_position_update_fn(logdensity_fn) + momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) + one_step = generalized_two_stage_integrator( + momentum_update_fn, + position_update_fn, + cofficients, + format_output_fn=format_euclidean_state_output, ) + return one_step + + return euclidean_integrator + + +""" +The velocity Verlet (or Verlet-Störmer) integrator. + +The velocity Verlet is a two-stage palindromic integrator :cite:p:`bou2018geometric` +of the form (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of +the step size that range between 0 and 2 (when the mass matrix is the identity). + +While the position (a1 = 0.5) and velocity Verlet are the most commonly used +in samplers, it is known in the numerical computation literature that the value +$a1 \approx 0.1932$ leads to a lower integration error :cite:p:`mclachlan1995numerical,schlick2010molecular`. +The authors of :cite:p:`bou2018geometric` show that the value $a1 \approx 0.21132$ +leads to an even higher step acceptance rate, up to 3 times higher +than with the standard position verlet (p.22, Fig.4). + +By choosing the velocity verlet we avoid two computations of the gradient +of the kinetic energy. We are trading accuracy in exchange, and it is not +clear whether this is the right tradeoff. +""" +velocity_verlet_cofficients = [0.5, 1.0, 0.5] +velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients) + +""" +Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. + +The integrator is of the form (b1, a1, b2, a1, b1). The choice of the parameters +determine both the bound on the integration error and the stability of the +method with respect to the value of `step_size`. 
The values used here are +the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` +is more focused on stability and derives different values. + +Also known as the minimal norm integrator. +""" +b1 = 0.1931833275037836 +a1 = 0.5 +b2 = 1 - 2 * b1 +mclachlan_cofficients = [b1, a1, b2, a1, b1] +mclachlan = generate_euclidean_integrator(mclachlan_cofficients) + +""" +Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` + +The integrator is of the form (b1, a1, b2, a2, b2, a1, b1). The choice of +the parameters determine both the bound on the integration error and the +stability of the method with respect to the value of `step_size`. The +values used here are the ones derived in :cite:p:`mclachlan1995numerical` which +guarantees a stability interval length approximately equal to 4.67. +""" +b1 = 0.11888010966548 +a1 = 0.29619504261126 +b2 = 0.5 - b1 +a2 = 1 - 2 * a1 +yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1] +yoshida = generate_euclidean_integrator(yoshida_cofficients) + + +# Intergrators with non Euclidean updates +def _normalized_flatten_array(x, tol=1e-13): + norm = jnp.linalg.norm(x) + return jnp.where(norm > tol, x / norm, x), norm + + +def esh_dynamics_momentum_update_one_step( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + previous_kinetic_energy_change=None, + is_last_call=False, +): + """Momentum update based on Esh dynamics. + + The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` + There are no exponentials e^delta, which prevents overflows when the gradient norm + is large. + """ - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, + flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_momentum, _ = ravel_pytree(momentum) + dims = flatten_momentum.shape[0] + normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) + momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) + delta = step_size * coef * gradient_norm / (dims - 1) + zeta = jnp.exp(-delta) + new_momentum_raw = ( + normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + + 2 * zeta * flatten_momentum + ) + new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) + next_momentum = unravel_fn(new_momentum_normalized) + kinetic_energy_change = ( + delta + - jnp.log(2) + + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) + ) + if previous_kinetic_energy_change is not None: + kinetic_energy_change += previous_kinetic_energy_change + if is_last_call: + kinetic_energy_change *= dims - 1 + return next_momentum, next_momentum, kinetic_energy_change + + +def format_noneuclidean_state_output( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, +): + del kinetic_grad, position_update_info + return ( + IntegratorState(position, momentum, logdensity, logdensity_grad), + momentum_update_info, + ) + + +def generate_noneuclidean_integrator(cofficients): + def noneuclidean_integrator( + logdensity_fn: Callable, *args, **kwargs + ) -> GeneralIntegrator: + position_update_fn = euclidean_position_update_fn(logdensity_fn) + one_step = generalized_two_stage_integrator( + esh_dynamics_momentum_update_one_step, + position_update_fn, + cofficients, + 
format_output_fn=format_noneuclidean_state_output, ) + return one_step - return IntegratorState(position, momentum, logdensity, logdensity_grad) - - return one_step - - -def palindromic_sequence(s): - # symetrize - s = s[:-1] + s[::-1] - # zip with alternating operators - return lambda O1, O2 : list(zip(itertools.cycle([O1, O2]), s)) - -def make_integrator(O1, O2, order): + return noneuclidean_integrator - sequence = order(O1,O2) - - def step(state: IntegratorState, step_size): - """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" - - xx, uu, ll, gg = state - total_r = 0 - - for O,factor in sequence: - xx, uu, ll, gg, r = jax.tree_util.tree_map(O(step_size*factor), xx, uu, ll, gg) - total_r += r - - kinetic_change = jax.numpy.sum(total_r) - return IntegratorState(xx, uu, ll, gg), kinetic_change - - return step - -def update_position_mclmc(grad_logp): - """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" - - def update(step_size, x, u, l, g): - xx = x + step_size * u - ll, gg = grad_logp(xx) - return xx, u, ll, gg, 0 - - return lambda O : functools.partial(update,O) - - -def update_momentum_mclmc(step_size): - """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) - similar to the implementation: https://github.com/gregversteeg/esh_dynamics - There are no exponentials e^delta, which prevents overflows when the gradient norm is large. - """ - - def update(x, u, l, g): - g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) - e = g / g_norm - ue = jax.numpy.dot(u, e) - # jax.debug.print("🤯 {x} inside momentum update 🤯", x=(ue)) - dim = u.shape[0] - delta = step_size * g_norm / (dim - 1) - zeta = jax.numpy.exp(-delta) - uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u - delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) - return x, uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), l, g, delta_r - return update -minimal_norm = lambda O1, O2: make_integrator(O1, O2, minimal_norm_sequence) - -minimal_norm_sequence = palindromic_sequence([0.1931833275037836, 0.5, 1.- 2.*0.1931833275037836]) -leapfrog_sequence = palindromic_sequence([0.5, 1.]) -yoshida_sequence = palindromic_sequence([0.11888010966548, 0.29619504261126, 0.5 - 0.11888010966548, 1 - 2 * 0.29619504261126]) - - -# def minimal_norm(O1, O2): -# lambda_c = 0.1931833275037836 # critical value of the lambda parameter for the minimal norm integrator - -# def step(state: IntegratorState, step_size): -# """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20.""" - -# # V T V T V -# # jax.debug.print("🤯 {x} inside integrator 1 🤯", x=(state.momentum, state.logdensity_grad)) -# uu, r1 = jax.tree_util.tree_map( -# lambda x, u, g: O1(step_size * lambda_c)(x, u, g), -# state.position, -# state.momentum, -# state.logdensity_grad, -# ) -# # jax.debug.print("🤯 {x} inside integrator 2 🤯", x=(uu)) - -# xx, _, ll, gg = jax.tree_util.tree_map( -# lambda x, u: O2(step_size, x, u), state.position, uu*0.5 -# ) -# uu, r2 = jax.tree_util.tree_map( -# lambda x, u, g: O1(step_size * (1 - 2 * lambda_c))(x, u, g), xx, uu, gg -# ) -# xx, _, ll, gg = jax.tree_util.tree_map( -# lambda x, u: O2(step_size, x, u), xx, uu*0.5 -# ) -# uu, r3 = jax.tree_util.tree_map( -# lambda x, u, g: O1(step_size * lambda_c)(x, u, g), xx, uu, gg -# ) - -# # kinetic energy change -# kinetic_change = (r1 + r2 + r3) * (uu.shape[0] - 1) - -# return xx, uu, ll, gg, kinetic_change - -# return step - - 
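For orientation, here is a minimal sketch (not part of the patch) of how a palindromic coefficient list such as [b1, a1, b2, a1, b1] drives the generalized two-stage integrator: even-indexed coefficients scale momentum updates and odd-indexed coefficients scale position updates, mirrored around the middle of the list. The names momentum_update and position_update below are hypothetical stand-ins for the update functions constructed above.

def one_step_from_coefficients(coefficients, momentum_update, position_update):
    # Sketch only: alternate momentum (even index) and position (odd index)
    # updates, each scaled by its coefficient times the step size.
    def one_step(position, momentum, step_size):
        for i, coef in enumerate(coefficients):
            if i % 2 == 0:
                momentum = momentum_update(momentum, position, coef * step_size)
            else:
                position = position_update(position, momentum, coef * step_size)
        return position, momentum
    return one_step

With the velocity Verlet list [0.5, 1.0, 0.5] this reduces to the familiar half kick, full drift, half kick pattern.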
-# def update_position_mclmc(grad_logp): -# """The position updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)""" - -# def update(step_size, x, u): -# xx = x + step_size * u -# ll, gg = grad_logp(xx) -# return xx, u, ll, gg - -# return update - - -# def update_momentum_mclmc(step_size): -# """The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf) -# similar to the implementation: https://github.com/gregversteeg/esh_dynamics -# There are no exponentials e^delta, which prevents overflows when the gradient norm is large. -# """ - -# def update(x, u, g): -# g_norm = jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(g))) -# e = g / g_norm -# ue = jax.numpy.dot(u, e) -# # jax.debug.print("🤯 {x} inside momentum update 🤯", x=(ue)) -# dim = u.shape[0] -# delta = step_size * g_norm / (dim - 1) -# zeta = jax.numpy.exp(-delta) -# uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u -# delta_r = delta - jax.numpy.log(2) + jax.numpy.log(1 + ue + (1 - ue) * zeta**2) -# return uu / jax.numpy.sqrt(jax.numpy.sum(jax.numpy.square(uu))), delta_r -# return update \ No newline at end of file +noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients) +noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) +noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 49d428949..aa7a16a38 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -17,13 +17,11 @@ import jax import jax.numpy as jnp -import blackjax.mcmc.integrators as integrators +from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLike, PRNGKey -__all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc"] - -MCLMCState = integrators.IntegratorState +__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] class MCLMCInfo(NamedTuple): @@ -46,9 +44,8 @@ class MCLMCInfo(NamedTuple): def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) - # jax.debug.print("🤯 {x} initial momentum 🤯", x=random_unit_vector(rng_key, dim=x_initial.shape[0])) - return MCLMCState( + return IntegratorState( position=x_initial, momentum=random_unit_vector(rng_key, dim=x_initial.shape[0]), logdensity=l, @@ -56,7 +53,7 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(grad_logp, integrator, transform): +def build_kernel(logdensity_fn, integrator, transform): """Build a HMC kernel. Parameters @@ -77,14 +74,11 @@ def build_kernel(grad_logp, integrator, transform): information about the transition. 
""" - step = integrator( - O1=integrators.update_momentum_mclmc, - O2=integrators.update_position_mclmc(grad_logp), - ) + step = integrator(logdensity_fn) def kernel( - rng_key: PRNGKey, state: MCLMCState, L: float, step_size: float - ) -> tuple[MCLMCState, MCLMCInfo]: + rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float + ) -> tuple[IntegratorState, MCLMCInfo]: (xx, uu, ll, gg), kinetic_change = step(state, step_size) dim = xx.shape[0] @@ -92,7 +86,7 @@ def kernel( nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) uu = partially_refresh_momentum(u=uu, rng_key=rng_key, nu=nu) - return MCLMCState(xx, uu, ll, gg), MCLMCInfo( + return IntegratorState(xx, uu, ll, gg), MCLMCInfo( transformed_x=transform(xx), logdensity=ll, dE=kinetic_change - ll + state.logdensity, @@ -161,11 +155,10 @@ def __new__( # type: ignore[misc] transform: Callable, L, step_size, - integrator=integrators.minimal_norm, + integrator=noneuclidean_mclachlan, ) -> SamplingAlgorithm: - grad_logp = jax.value_and_grad(logdensity_fn) - kernel = cls.build_kernel(grad_logp, integrator, transform) + kernel = cls.build_kernel(logdensity_fn, integrator, transform) def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) From 0c8330e83b279fa8a3d49d345ef252014e76e120 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 23 Nov 2023 13:54:39 -0500 Subject: [PATCH 43/78] add references --- docs/refs.bib | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/docs/refs.bib b/docs/refs.bib index f5015ccb9..689bc0f6d 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -360,3 +360,28 @@ @inproceedings{hoffman2021adaptive year={2021}, organization={PMLR} } + +@article{ver2021hamiltonian, + title={Hamiltonian dynamics with non-newtonian momentum for rapid sampling}, + author={Ver Steeg, Greg and Galstyan, Aram}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={11012--11025}, + year={2021} +} + +@article{robnik2023microcanonical, + title={Microcanonical Hamiltonian Monte Carlo}, + author={Robnik, Jakob and De Luca, G Bruno and Silverstein, Eva and Seljak, Uro{\v{s}}}, + journal={Journal of Machine Learning Research}, + volume={24}, + pages={1--34}, + year={2023} +} + +@article{robnik2023microcanonical, + title={Microcanonical Langevin Monte Carlo}, + author={Robnik, Jakob and Seljak, Uro{\v{s}}}, + journal={arXiv preprint arXiv:2303.18221}, + year={2023} +} \ No newline at end of file From 40fc61c9334843b71eab74d95ea8c731b194dd6b Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 Nov 2023 18:01:01 -0500 Subject: [PATCH 44/78] flake --- blackjax/adaptation/step_size.py | 19 +++++++++-------- blackjax/mcmc/mclmc.py | 35 ++++++++++++++++---------------- docs/refs.bib | 2 +- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index b16807985..3016980a5 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Step size adaptation""" -from typing import Callable, NamedTuple import warnings +from typing import Callable, NamedTuple import jax import jax.numpy as jnp from scipy.fft import next_fast_len -from blackjax.diagnostics import effective_sample_size +from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.hmc import HMCState from blackjax.mcmc.mclmc import IntegratorState from blackjax.optimizers.dual_averaging import dual_averaging -from blackjax.types import Array, ArrayLikeTree, PRNGKey +from blackjax.types import PRNGKey __all__ = [ "DualAveragingAdaptationState", @@ -270,7 +270,6 @@ class MCLMCAdaptationState(NamedTuple): step_size: float - def ess_corr(x): """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html shape(x) = (num_samples, d)""" @@ -301,7 +300,7 @@ def ess_corr(x): / (num_samples - 1.0) ) weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - + weighted_var = jax.lax.cond( num_chains > 1, lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), @@ -492,11 +491,13 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): Lfactor = 0.4 ESS2 = effective_sample_size(info.transformed_x) neff = ESS2.squeeze() / info.transformed_x.shape[0] - ESS_alt = 1.0 / jnp.average(1 / neff) - ESS = ess_corr(info.transformed_x) + # ESS_alt = 1.0 / jnp.average(1 / neff) + ESS = ess_corr(info.transformed_x) if ESS * num_steps <= 10: - warnings.warn("tune3 cannot be expected to work with 10 or fewer effective samples") - + warnings.warn( + "tune3 cannot be expected to work with 10 or fewer effective samples" + ) + Lnew = Lfactor * eps / ESS return Lnew, state diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index aa7a16a38..057ee57f5 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -17,8 +17,8 @@ import jax import jax.numpy as jnp -from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -79,18 +79,18 @@ def build_kernel(logdensity_fn, integrator, transform): def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: - (xx, uu, ll, gg), kinetic_change = step(state, step_size) - - dim = xx.shape[0] + (position, momentum, logdensity, logdensitygrad), kinetic_change = step(state, step_size) + + dim = position.shape[0] # Langevin-like noise nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) - uu = partially_refresh_momentum(u=uu, rng_key=rng_key, nu=nu) + momentum = partially_refresh_momentum(momentum=momentum, rng_key=rng_key, nu=nu) - return IntegratorState(xx, uu, ll, gg), MCLMCInfo( - transformed_x=transform(xx), - logdensity=ll, - dE=kinetic_change - ll + state.logdensity, - kinetic_change=kinetic_change*(uu.shape[0] - 1), + return IntegratorState(position, momentum, logdensity, logdensitygrad), MCLMCInfo( + transformed_x=transform(position), + logdensity=logdensity, + dE=kinetic_change - logdensity + state.logdensity, + kinetic_change=kinetic_change * (dim - 1), ) return kernel @@ -157,7 +157,6 @@ def __new__( # type: ignore[misc] step_size, integrator=noneuclidean_mclachlan, ) -> SamplingAlgorithm: - kernel = cls.build_kernel(logdensity_fn, integrator, transform) def update_fn(rng_key, state): @@ -175,12 +174,12 @@ def init_fn(position: ArrayLike): def 
random_unit_vector(rng_key, dim): - u = jax.random.normal(rng_key, shape=(dim,)) - u /= jnp.sqrt(jnp.sum(jnp.square(u))) - return u + momentum = jax.random.normal(rng_key, shape=(dim,)) + momentum /= jnp.sqrt(jnp.sum(jnp.square(momentum))) + return momentum -def partially_refresh_momentum(u, rng_key, nu): - """Adds a small noise to u and normalizes.""" - z = nu * jax.random.normal(rng_key, shape=(u.shape[0],)) - return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) +def partially_refresh_momentum(momentum, rng_key, nu): + """Adds a small noise to momentum and normalizes.""" + z = nu * jax.random.normal(rng_key, shape=(momentum.shape[0],)) + return (momentum + z) / jnp.sqrt(jnp.sum(jnp.square(momentum + z))) diff --git a/docs/refs.bib b/docs/refs.bib index fcd5cf8f7..acbb207e3 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -384,4 +384,4 @@ @misc{robnik2023microcanonical2 author={Robnik, Jakob and Seljak, Uro{\v{s}}}, journal={arXiv preprint arXiv:2303.18221}, year={2023} -} \ No newline at end of file +} From 6ea53203dca603a2c64ca1ac7d998362a591ec3e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 Nov 2023 18:16:00 -0500 Subject: [PATCH 45/78] temporarily add 'explore' --- explore.py | 87 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/explore.py b/explore.py index c19beb638..1cddbf820 100644 --- a/explore.py +++ b/explore.py @@ -1,18 +1,21 @@ import math from typing import NamedTuple -from chex import Array + import jax import jax.numpy as jnp +from chex import Array from scipy.fftpack import next_fast_len # type: ignore import blackjax from blackjax.adaptation.step_size import MCLMCAdaptationState, tune from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.integrators import minimal_norm -from blackjax.mcmc.mclmc import MCLMCState, build_kernel +from blackjax.mcmc.integrators import noneuclidean_mclachlan +from blackjax.mcmc.mclmc import build_kernel + # from blackjax.diagnostics import effective_sample_size from blackjax.types import PRNGKey + def logdensity_fn(x): return -0.5 * jnp.sum(jnp.square(x)) @@ -22,41 +25,67 @@ def run_sampling_algorithm( ): keys = jax.random.split(rng_key, num_steps) state = sampling_algorithm.init(initial_val) - _, info = jax.lax.scan( - lambda s, k: (sampling_algorithm.step(k, s)), state, keys - ) + _, info = jax.lax.scan(lambda s, k: (sampling_algorithm.step(k, s)), state, keys) return info +def tune_and_run(logdensity_fn, key, dim, num_steps): + main_key, tune_key = jax.random.split(key) + identity = lambda x: x -key = jax.random.PRNGKey(0) -main_key, tune_key = jax.random.split(key) - -num_steps = 10000 -dim = 2 + params, state = tune( + params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), + kernel=build_kernel( + logdensity_fn, integrator=noneuclidean_mclachlan, transform=identity + ), + num_steps=num_steps, + rng_key=tune_key, + ) -params, state = tune( - params = MCLMCAdaptationState(L = math.sqrt(dim), step_size = math.sqrt(dim) * 0.4), - kernel=build_kernel(grad_logp=jax.value_and_grad(logdensity_fn), - integrator=minimal_norm, transform=lambda x:x), num_steps=num_steps, rng_key=tune_key) + mclmc = blackjax.mcmc.mclmc.mclmc( + logdensity_fn=logdensity_fn, + transform=lambda x: x, + L=params.L, + step_size=params.step_size, + ) -mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - transform=lambda x: x, - L=params.L, step_size=params.step_size -) + print( + f"L is {params.L} and should be {1.3147894144058228} and step_size is 
{params.step_size} and should be {0.6470216512680054}" + ) + return run_sampling_algorithm( + sampling_algorithm=mclmc, + num_steps=num_steps, + initial_val=state.position, + rng_key=main_key, + ) -out = run_sampling_algorithm( - sampling_algorithm=mclmc, - num_steps=num_steps, - initial_val=state.position, - rng_key=main_key, -) +out = tune_and_run(logdensity_fn=logdensity_fn, key=jax.random.PRNGKey(0), dim=2, num_steps=10000) print(jnp.mean(out.transformed_x, axis=0)) -print(f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}") -assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 -assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 2.8414153e-05])) +# assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 +# assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 2.8414153e-05])) +assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([0.00296992, 0.00087555])) + + + +# def test_mclmc(self): +# """Test the MCLMC kernel.""" +# init_key0, init_key1, inference_key = jax.random.split(self.key, 3) +# x_data = jax.random.normal(init_key0, shape=(1000, 1)) +# y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + +# logposterior_fn_ = functools.partial( +# self.regression_logprob, x=x_data, preds=y_data +# ) +# logposterior_fn = lambda x: logposterior_fn_(**x) + +# mala = blackjax.mcmc.mclmc.mclmc(logposterior_fn, 1e-5) +# state = mala.init({"coefs": 1.0, "log_scale": 1.0}) +# states = inference_loop(mala.step, 10_000, inference_key, state) +# coefs_samples = states.position["coefs"][3000:] +# scale_samples = np.exp(states.position["log_scale"][3000:]) +# np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) +# np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) \ No newline at end of file From c83dc1a98d8c12b3a481f8abf9fe37dfc2f85ce3 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 25 Nov 2023 17:19:57 -0500 Subject: [PATCH 46/78] temporarily add 'explore' --- explore.py | 1 + 1 file changed, 1 insertion(+) diff --git a/explore.py b/explore.py index 1cddbf820..ec4f4a3ae 100644 --- a/explore.py +++ b/explore.py @@ -63,6 +63,7 @@ def tune_and_run(logdensity_fn, key, dim, num_steps): print(jnp.mean(out.transformed_x, axis=0)) + # assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 # assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 2.8414153e-05])) assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([0.00296992, 0.00087555])) From c8b43be2867729a6005c191b08e9444405954637 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sun, 26 Nov 2023 21:37:31 +0100 Subject: [PATCH 47/78] Adding a test for energy preservation. Co-authored-by: Reuben Cohn-Gordon --- blackjax/util.py | 18 +++++++++++++++ tests/mcmc/test_integrators.py | 40 ++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/blackjax/util.py b/blackjax/util.py index 1a7ebcd09..c4a840c21 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -82,6 +82,24 @@ def generate_gaussian_noise( return unravel_fn(mu + linear_map(sigma, sample)) +def generate_unit_vector(rng_key: PRNGKey, position: ArrayLikeTree,) -> Array: + """Generate a random unit vector with output structure that match a given PyTree. + + Parameters + ---------- + rng_key: + The pseudo-random number generator key used to generate random numbers. 
+ position: + PyTree that the structure the output should to match. + + Returns + ------- + Random unit vector that match the structure of position. + """ + p, unravel_fn = ravel_pytree(position) + sample = normal(rng_key, shape=p.shape, dtype=p.dtype) + return unravel_fn(sample / jnp.linalg.norm(sample)) + def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" return sum(jnp.size(value) for value in tree_leaves(pytree)) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index a41877c13..127079312 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -10,6 +10,7 @@ import blackjax.mcmc.integrators as integrators from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step +from blackjax.util import generate_unit_vector def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0): @@ -130,6 +131,9 @@ def kinetic_energy(p): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "noneuclidean_leapfrog" : {"algorithm": integrators.noneuclidean_leapfrog}, + "noneuclidean_mclachlan" : {"algorithm": integrators.noneuclidean_mclachlan}, + "noneuclidean_yoshida" : {"algorithm": integrators.noneuclidean_yoshida}, } @@ -224,6 +228,42 @@ def test_esh_momentum_update(self, dims): next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) + @chex.all_variants(with_pmap=False) + @parameterized.parameters( + [ + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", + ], + ) + def test_noneuclidean_integrator(self, integrator_name): + integrator = algorithms[integrator_name] + cov = jnp.asarray([[1., .5] , [.5, 2.]]) + logdensity_fn = lambda x: stats.multivariate_normal.logpdf( + x, jnp.zeros([2]), cov) + + step = self.variant(integrator["algorithm"](logdensity_fn)) + + rng = jax.random.key(4263456) + key0, key1 = jax.random.split(rng, 2) + position_init = jax.random.normal(key0, (2,)) + momentum_init = generate_unit_vector(key1, position_init) + step_size = .0001 + initial_state = integrators.new_integrator_state( + logdensity_fn, position_init, momentum_init) + + final_state, kinetic_energy_change = jax.lax.scan( + lambda state, _: step(state, step_size), + initial_state, + xs=None, + length=15, + ) + + # Check the conservation of energy. 
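+        # Energy is conserved by the exact dynamics, so for a small enough
+        # step size the kinetic energy change reported by the integrator and
+        # the change in log-density should sum to approximately zero.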
+ potential_energy_change = final_state.logdensity - initial_state.logdensity + energy_change = kinetic_energy_change[-1] + potential_energy_change + self.assertAlmostEqual(energy_change, 0, delta=1e-3) + if __name__ == "__main__": absltest.main() From 88942487daa046e2c9811be74b86b66db89ec3c3 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sun, 26 Nov 2023 21:40:21 +0100 Subject: [PATCH 48/78] fix formatting --- blackjax/util.py | 6 +++++- tests/mcmc/test_integrators.py | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index c4a840c21..a3a7226a6 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -82,7 +82,10 @@ def generate_gaussian_noise( return unravel_fn(mu + linear_map(sigma, sample)) -def generate_unit_vector(rng_key: PRNGKey, position: ArrayLikeTree,) -> Array: +def generate_unit_vector( + rng_key: PRNGKey, + position: ArrayLikeTree, +) -> Array: """Generate a random unit vector with output structure that match a given PyTree. Parameters @@ -100,6 +103,7 @@ def generate_unit_vector(rng_key: PRNGKey, position: ArrayLikeTree,) -> Array: sample = normal(rng_key, shape=p.shape, dtype=p.dtype) return unravel_fn(sample / jnp.linalg.norm(sample)) + def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" return sum(jnp.size(value) for value in tree_leaves(pytree)) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 127079312..2f5020d00 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -131,9 +131,9 @@ def kinetic_energy(p): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, - "noneuclidean_leapfrog" : {"algorithm": integrators.noneuclidean_leapfrog}, - "noneuclidean_mclachlan" : {"algorithm": integrators.noneuclidean_mclachlan}, - "noneuclidean_yoshida" : {"algorithm": integrators.noneuclidean_yoshida}, + "noneuclidean_leapfrog": {"algorithm": integrators.noneuclidean_leapfrog}, + "noneuclidean_mclachlan": {"algorithm": integrators.noneuclidean_mclachlan}, + "noneuclidean_yoshida": {"algorithm": integrators.noneuclidean_yoshida}, } @@ -230,17 +230,18 @@ def test_esh_momentum_update(self, dims): @chex.all_variants(with_pmap=False) @parameterized.parameters( - [ - "noneuclidean_leapfrog", - "noneuclidean_mclachlan", - "noneuclidean_yoshida", - ], + [ + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", + ], ) def test_noneuclidean_integrator(self, integrator_name): integrator = algorithms[integrator_name] - cov = jnp.asarray([[1., .5] , [.5, 2.]]) + cov = jnp.asarray([[1.0, 0.5], [0.5, 2.0]]) logdensity_fn = lambda x: stats.multivariate_normal.logpdf( - x, jnp.zeros([2]), cov) + x, jnp.zeros([2]), cov + ) step = self.variant(integrator["algorithm"](logdensity_fn)) @@ -248,9 +249,10 @@ def test_noneuclidean_integrator(self, integrator_name): key0, key1 = jax.random.split(rng, 2) position_init = jax.random.normal(key0, (2,)) momentum_init = generate_unit_vector(key1, position_init) - step_size = .0001 + step_size = 0.0001 initial_state = integrators.new_integrator_state( - logdensity_fn, position_init, momentum_init) + logdensity_fn, position_init, momentum_init + ) final_state, kinetic_energy_change = jax.lax.scan( lambda state, _: step(state, step_size), From 9865145ac301455e677534162f06125fdc5ca92b Mon Sep 17 
00:00:00 2001 From: = Date: Sun, 26 Nov 2023 22:03:45 +0100 Subject: [PATCH 49/78] wip: tests --- blackjax/adaptation/step_size.py | 27 +++++++++++++------- explore.py | 42 +++++++++----------------------- tests/mcmc/test_sampling.py | 21 ++++++++++++++++ 3 files changed, 50 insertions(+), 40 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 3016980a5..4a495a7a8 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -21,7 +21,8 @@ from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.mclmc import IntegratorState +from blackjax.mcmc.integrators import noneuclidean_mclachlan +from blackjax.mcmc.mclmc import IntegratorState, build_kernel, init from blackjax.optimizers.dual_averaging import dual_averaging from blackjax.types import PRNGKey @@ -503,19 +504,27 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): def tune( - kernel, num_steps: int, rng_key: PRNGKey, params: MCLMCAdaptationState + position, logdensity_fn, num_steps: int, rng_key: PRNGKey, params: MCLMCAdaptationState ) -> tuple[MCLMCAdaptationState, IntegratorState]: num_tune_step_ratio_1 = 0.1 num_tune_step_ratio_2 = 0.1 - x, u, l, g = ( - jnp.array([0.1, 0.1]), - jnp.array([-0.6755803, 0.73728645]), - -0.010000001, - -jnp.array([0.1, 0.1]), - ) + kernel=build_kernel( + logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x:x + ) + + + + init_key, tune1_key, tune2_key = jax.random.split(rng_key, 3) + + x,u,l,g = init(position, logdensity_fn=logdensity_fn, rng_key=init_key) + # x, u, l, g = ( + # jnp.array([0.1, 0.1]), + # jnp.array([-0.6755803, 0.73728645]), + # -0.010000001, + # -jnp.array([0.1, 0.1]), + # ) - tune1_key, tune2_key = jax.random.split(rng_key) L, eps, state = tune12( kernel, diff --git a/explore.py b/explore.py index ec4f4a3ae..35f7b4850 100644 --- a/explore.py +++ b/explore.py @@ -1,9 +1,12 @@ +import functools import math from typing import NamedTuple import jax import jax.numpy as jnp from chex import Array +import numpy as np +from scipy import stats from scipy.fftpack import next_fast_len # type: ignore import blackjax @@ -14,6 +17,7 @@ # from blackjax.diagnostics import effective_sample_size from blackjax.types import PRNGKey +# from tests.mcmc.test_sampling import inference_loop def logdensity_fn(x): @@ -28,15 +32,13 @@ def run_sampling_algorithm( _, info = jax.lax.scan(lambda s, k: (sampling_algorithm.step(k, s)), state, keys) return info -def tune_and_run(logdensity_fn, key, dim, num_steps): +def tune_and_run(position, logdensity_fn, key, dim, num_steps): main_key, tune_key = jax.random.split(key) - identity = lambda x: x params, state = tune( + position=position, params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), - kernel=build_kernel( - logdensity_fn, integrator=noneuclidean_mclachlan, transform=identity - ), + logdensity_fn=logdensity_fn, num_steps=num_steps, rng_key=tune_key, ) @@ -58,35 +60,13 @@ def tune_and_run(logdensity_fn, key, dim, num_steps): rng_key=main_key, ) -out = tune_and_run(logdensity_fn=logdensity_fn, key=jax.random.PRNGKey(0), dim=2, num_steps=10000) +out = tune_and_run(position=jnp.array([10.0, 10.0]), logdensity_fn=logdensity_fn, key=jax.random.PRNGKey(0), dim=2, num_steps=10000) print(jnp.mean(out.transformed_x, axis=0)) -# assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 -# assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 
2.8414153e-05])) -assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([0.00296992, 0.00087555])) - - - -# def test_mclmc(self): -# """Test the MCLMC kernel.""" -# init_key0, init_key1, inference_key = jax.random.split(self.key, 3) -# x_data = jax.random.normal(init_key0, shape=(1000, 1)) -# y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) - -# logposterior_fn_ = functools.partial( -# self.regression_logprob, x=x_data, preds=y_data -# ) -# logposterior_fn = lambda x: logposterior_fn_(**x) - -# mala = blackjax.mcmc.mclmc.mclmc(logposterior_fn, 1e-5) -# state = mala.init({"coefs": 1.0, "log_scale": 1.0}) -# states = inference_loop(mala.step, 10_000, inference_key, state) - -# coefs_samples = states.position["coefs"][3000:] -# scale_samples = np.exp(states.position["log_scale"][3000:]) +# # assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 +# # assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 2.8414153e-05])) +# assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([0.00296992, 0.00087555])) -# np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) -# np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) \ No newline at end of file diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7770b55a1..59560f505 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -13,6 +13,7 @@ import blackjax import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk +from explore import tune_and_run def inference_loop(kernel, num_samples, rng_key, initial_state): @@ -142,6 +143,26 @@ def test_mala(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) + + def test_mclmc(self): + """Test the MCLMC kernel.""" + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = tune_and_run(position={"coefs": 1.0, "log_scale": 1.0}, logdensity_fn=logdensity_fn, key=inference_key, dim=2, num_steps=10000) + + + coefs_samples = states.transformed_x["coefs"][3000:] + scale_samples = np.exp(states.transformed_x["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( From 0c614128b65986e33c2e7a1486f61868f44d2e57 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 26 Nov 2023 23:29:54 +0100 Subject: [PATCH 50/78] use pytrees for partially_refresh_momentum, and add test --- blackjax/mcmc/mclmc.py | 26 +++++++------------------- blackjax/util.py | 25 +++++++++++++++++++++++++ explore.py | 29 ++++++++++++++++------------- 3 files changed, 48 insertions(+), 32 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 057ee57f5..a18d6e337 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -16,10 +16,12 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, 
PRNGKey +from blackjax.util import generate_unit_vector, partially_refresh_momentum __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -47,7 +49,7 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): return IntegratorState( position=x_initial, - momentum=random_unit_vector(rng_key, dim=x_initial.shape[0]), + momentum=generate_unit_vector(rng_key, x_initial), logdensity=l, logdensity_grad=g, ) @@ -81,10 +83,11 @@ def kernel( ) -> tuple[IntegratorState, MCLMCInfo]: (position, momentum, logdensity, logdensitygrad), kinetic_change = step(state, step_size) - dim = position.shape[0] + # dim = position.shape[0] + dim = 2 # Langevin-like noise - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) - momentum = partially_refresh_momentum(momentum=momentum, rng_key=rng_key, nu=nu) + + momentum, dim = partially_refresh_momentum(momentum=momentum, rng_key=rng_key, L=L, step_size=step_size) return IntegratorState(position, momentum, logdensity, logdensitygrad), MCLMCInfo( transformed_x=transform(position), @@ -168,18 +171,3 @@ def init_fn(position: ArrayLike): return SamplingAlgorithm(init_fn, update_fn) -### -# helper funcs -### - - -def random_unit_vector(rng_key, dim): - momentum = jax.random.normal(rng_key, shape=(dim,)) - momentum /= jnp.sqrt(jnp.sum(jnp.square(momentum))) - return momentum - - -def partially_refresh_momentum(momentum, rng_key, nu): - """Adds a small noise to momentum and normalizes.""" - z = nu * jax.random.normal(rng_key, shape=(momentum.shape[0],)) - return (momentum + z) / jnp.sqrt(jnp.sum(jnp.square(momentum + z))) diff --git a/blackjax/util.py b/blackjax/util.py index a3a7226a6..b3f419c0a 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -103,6 +103,31 @@ def generate_unit_vector( sample = normal(rng_key, shape=p.shape, dtype=p.dtype) return unravel_fn(sample / jnp.linalg.norm(sample)) +def partially_refresh_momentum(momentum, rng_key, step_size, L): + """Adds a small noise to momentum and normalizes. + + Parameters + ---------- + rng_key: + The pseudo-random number generator key used to generate random numbers. + momentum: + PyTree that the structure the output should to match. 
+ step_size: + Step size + L: + controls rate of momentum change + + Returns + ------- + momentum with random change in angle + """ + m, unravel_fn = ravel_pytree(momentum) + dim = m.shape[0] + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) + return unravel_fn((m + z) / jnp.sqrt(jnp.sum(jnp.square(m + z)))), dim + + def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" diff --git a/explore.py b/explore.py index 35f7b4850..dfda2179b 100644 --- a/explore.py +++ b/explore.py @@ -35,28 +35,31 @@ def run_sampling_algorithm( def tune_and_run(position, logdensity_fn, key, dim, num_steps): main_key, tune_key = jax.random.split(key) - params, state = tune( - position=position, - params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), - logdensity_fn=logdensity_fn, - num_steps=num_steps, - rng_key=tune_key, - ) + # params, state = tune( + # position=position, + # params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), + # logdensity_fn=logdensity_fn, + # num_steps=num_steps, + # rng_key=tune_key, + # ) + # print( + # f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}" + # ) mclmc = blackjax.mcmc.mclmc.mclmc( logdensity_fn=logdensity_fn, transform=lambda x: x, - L=params.L, - step_size=params.step_size, + # L=params.L, + # step_size=params.step_size, + L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4 ) - print( - f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}" - ) + return run_sampling_algorithm( sampling_algorithm=mclmc, num_steps=num_steps, - initial_val=state.position, + # initial_val=state.position, + initial_val=position, rng_key=main_key, ) From be07631717d61aec39f2deda11b1bf7a0c43d51c Mon Sep 17 00:00:00 2001 From: = Date: Mon, 27 Nov 2023 15:28:56 +0100 Subject: [PATCH 51/78] update docstring --- blackjax/mcmc/mclmc.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index a18d6e337..7dd5a52ec 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,9 +15,6 @@ from typing import Callable, NamedTuple import jax -import jax.numpy as jnp -from jax.flatten_util import ravel_pytree - from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey @@ -27,15 +24,17 @@ class MCLMCInfo(NamedTuple): - """Additional information on the MCLMC transition. - - transformed_x - The value of the samples after a transformation (e.g. projection onto lower dim subspace) - logdensity - logdensity at given step - dE - energy difference + """ + Additional information on the MCLMC transition. + Attributes + ---------- + transformed_x : + The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace. + logdensity : + The log-density of the distribution at the current step of the MCLMC chain. + dE : + The difference in energy between the current and previous step. 
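+    kinetic_change :
+        The change in kinetic energy over the current step, as reported by the kernel.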
""" transformed_x: Array From a170d0be3cb52e6a54c96bec5f687169ac1658c8 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 27 Nov 2023 15:32:05 +0100 Subject: [PATCH 52/78] remove 'explore' --- explore.py | 75 ------------------------------------------------------ 1 file changed, 75 deletions(-) delete mode 100644 explore.py diff --git a/explore.py b/explore.py deleted file mode 100644 index dfda2179b..000000000 --- a/explore.py +++ /dev/null @@ -1,75 +0,0 @@ -import functools -import math -from typing import NamedTuple - -import jax -import jax.numpy as jnp -from chex import Array -import numpy as np -from scipy import stats -from scipy.fftpack import next_fast_len # type: ignore - -import blackjax -from blackjax.adaptation.step_size import MCLMCAdaptationState, tune -from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.integrators import noneuclidean_mclachlan -from blackjax.mcmc.mclmc import build_kernel - -# from blackjax.diagnostics import effective_sample_size -from blackjax.types import PRNGKey -# from tests.mcmc.test_sampling import inference_loop - - -def logdensity_fn(x): - return -0.5 * jnp.sum(jnp.square(x)) - - -def run_sampling_algorithm( - sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key -): - keys = jax.random.split(rng_key, num_steps) - state = sampling_algorithm.init(initial_val) - _, info = jax.lax.scan(lambda s, k: (sampling_algorithm.step(k, s)), state, keys) - return info - -def tune_and_run(position, logdensity_fn, key, dim, num_steps): - main_key, tune_key = jax.random.split(key) - - # params, state = tune( - # position=position, - # params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), - # logdensity_fn=logdensity_fn, - # num_steps=num_steps, - # rng_key=tune_key, - # ) - # print( - # f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}" - # ) - - mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - transform=lambda x: x, - # L=params.L, - # step_size=params.step_size, - L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4 - ) - - - return run_sampling_algorithm( - sampling_algorithm=mclmc, - num_steps=num_steps, - # initial_val=state.position, - initial_val=position, - rng_key=main_key, - ) - -out = tune_and_run(position=jnp.array([10.0, 10.0]), logdensity_fn=logdensity_fn, key=jax.random.PRNGKey(0), dim=2, num_steps=10000) - -print(jnp.mean(out.transformed_x, axis=0)) - - - -# # assert params.L==1.3147894144058228 and params.step_size==0.6470216512680054 -# # assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([1.9507202e-03, 2.8414153e-05])) -# assert jnp.allclose(jnp.mean(out.transformed_x, axis=0), jnp.array([0.00296992, 0.00087555])) - From 8cfb75f4df080e7f642d22e827aaf92af5363cff Mon Sep 17 00:00:00 2001 From: = Date: Mon, 27 Nov 2023 15:37:05 +0100 Subject: [PATCH 53/78] fix pre-commit --- blackjax/adaptation/step_size.py | 22 ++++++------ blackjax/mcmc/mclmc.py | 19 ++++++---- blackjax/util.py | 2 +- tests/mcmc/test_sampling.py | 62 +++++++++++++++++++++++--------- 4 files changed, 69 insertions(+), 36 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 4a495a7a8..b2d7326cc 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -19,7 +19,6 @@ import jax.numpy as jnp from scipy.fft import next_fast_len -from blackjax.diagnostics import effective_sample_size from blackjax.mcmc.hmc import HMCState from blackjax.mcmc.integrators 
import noneuclidean_mclachlan from blackjax.mcmc.mclmc import IntegratorState, build_kernel, init @@ -490,8 +489,8 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): ) Lfactor = 0.4 - ESS2 = effective_sample_size(info.transformed_x) - neff = ESS2.squeeze() / info.transformed_x.shape[0] + # ESS2 = effective_sample_size(info.transformed_x) + # neff = ESS2.squeeze() / info.transformed_x.shape[0] # ESS_alt = 1.0 / jnp.average(1 / neff) ESS = ess_corr(info.transformed_x) if ESS * num_steps <= 10: @@ -504,20 +503,22 @@ def tune3(kernel, state, rng_key, L, eps, num_steps): def tune( - position, logdensity_fn, num_steps: int, rng_key: PRNGKey, params: MCLMCAdaptationState + position, + logdensity_fn, + num_steps: int, + rng_key: PRNGKey, + params: MCLMCAdaptationState, ) -> tuple[MCLMCAdaptationState, IntegratorState]: num_tune_step_ratio_1 = 0.1 num_tune_step_ratio_2 = 0.1 - kernel=build_kernel( - logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x:x - ) - + kernel = build_kernel( + logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x: x + ) - init_key, tune1_key, tune2_key = jax.random.split(rng_key, 3) - x,u,l,g = init(position, logdensity_fn=logdensity_fn, rng_key=init_key) + x, u, l, g = init(position, logdensity_fn=logdensity_fn, rng_key=init_key) # x, u, l, g = ( # jnp.array([0.1, 0.1]), # jnp.array([-0.6755803, 0.73728645]), @@ -525,7 +526,6 @@ def tune( # -jnp.array([0.1, 0.1]), # ) - L, eps, state = tune12( kernel, x, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 7dd5a52ec..3700ee821 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,6 +15,7 @@ from typing import Callable, NamedTuple import jax + from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey @@ -33,7 +34,7 @@ class MCLMCInfo(NamedTuple): The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace. logdensity : The log-density of the distribution at the current step of the MCLMC chain. - dE : + dE : The difference in energy between the current and previous step. 
""" @@ -80,15 +81,21 @@ def build_kernel(logdensity_fn, integrator, transform): def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: - (position, momentum, logdensity, logdensitygrad), kinetic_change = step(state, step_size) + (position, momentum, logdensity, logdensitygrad), kinetic_change = step( + state, step_size + ) # dim = position.shape[0] dim = 2 # Langevin-like noise - - momentum, dim = partially_refresh_momentum(momentum=momentum, rng_key=rng_key, L=L, step_size=step_size) - return IntegratorState(position, momentum, logdensity, logdensitygrad), MCLMCInfo( + momentum, dim = partially_refresh_momentum( + momentum=momentum, rng_key=rng_key, L=L, step_size=step_size + ) + + return IntegratorState( + position, momentum, logdensity, logdensitygrad + ), MCLMCInfo( transformed_x=transform(position), logdensity=logdensity, dE=kinetic_change - logdensity + state.logdensity, @@ -168,5 +175,3 @@ def init_fn(position: ArrayLike): return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) return SamplingAlgorithm(init_fn, update_fn) - - diff --git a/blackjax/util.py b/blackjax/util.py index b3f419c0a..1c764137d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -103,6 +103,7 @@ def generate_unit_vector( sample = normal(rng_key, shape=p.shape, dtype=p.dtype) return unravel_fn(sample / jnp.linalg.norm(sample)) + def partially_refresh_momentum(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. @@ -128,7 +129,6 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): return unravel_fn((m + z) / jnp.sqrt(jnp.sum(jnp.square(m + z)))), dim - def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" return sum(jnp.size(value) for value in tree_leaves(pytree)) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 59560f505..ea43cf56c 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -13,7 +13,6 @@ import blackjax import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk -from explore import tune_and_run def inference_loop(kernel, num_samples, rng_key, initial_state): @@ -85,6 +84,36 @@ def regression_logprob(self, log_scale, coefs, preds, x): # reduce sum otherwise broacasting will make the logprob biased. 
return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) + # def tune_and_run(position, logdensity_fn, key, dim, num_steps): + # main_key, tune_key = jax.random.split(key) + + # # params, state = tune( + # # position=position, + # # params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), + # # logdensity_fn=logdensity_fn, + # # num_steps=num_steps, + # # rng_key=tune_key, + # # ) + # # print( + # # f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}" + # # ) + + # mclmc = blackjax.mcmc.mclmc.mclmc( + # logdensity_fn=logdensity_fn, + # transform=lambda x: x, + # # L=params.L, + # # step_size=params.step_size, + # L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4 + # ) + + # return run_sampling_algorithm( + # sampling_algorithm=mclmc, + # num_steps=num_steps, + # # initial_val=state.position, + # initial_val=position, + # rng_key=main_key, + # ) + @parameterized.parameters(itertools.product(regression_test_cases, [True, False])) def test_window_adaptation(self, case, is_mass_matrix_diagonal): """Test the HMC kernel and the Stan warmup.""" @@ -143,26 +172,25 @@ def test_mala(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) - - def test_mclmc(self): - """Test the MCLMC kernel.""" - init_key0, init_key1, inference_key = jax.random.split(self.key, 3) - x_data = jax.random.normal(init_key0, shape=(1000, 1)) - y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) - logposterior_fn_ = functools.partial( - self.regression_logprob, x=x_data, preds=y_data - ) - logdensity_fn = lambda x: logposterior_fn_(**x) + # def test_mclmc(self): + # """Test the MCLMC kernel.""" + # init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + # x_data = jax.random.normal(init_key0, shape=(1000, 1)) + # y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) - states = tune_and_run(position={"coefs": 1.0, "log_scale": 1.0}, logdensity_fn=logdensity_fn, key=inference_key, dim=2, num_steps=10000) + # logposterior_fn_ = functools.partial( + # self.regression_logprob, x=x_data, preds=y_data + # ) + # logdensity_fn = lambda x: logposterior_fn_(**x) - - coefs_samples = states.transformed_x["coefs"][3000:] - scale_samples = np.exp(states.transformed_x["log_scale"][3000:]) + # states = tune_and_run(position={"coefs": 1.0, "log_scale": 1.0}, logdensity_fn=logdensity_fn, key=inference_key, dim=2, num_steps=10000) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) + # coefs_samples = states.transformed_x["coefs"][3000:] + # scale_samples = np.exp(states.transformed_x["log_scale"][3000:]) + + # np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) + # np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( From b42e77eec763222d700b1c6a8f03fbcd470d1f50 Mon Sep 17 00:00:00 2001 From: "jakob.robnik" Date: Wed, 29 Nov 2023 19:17:04 +0100 Subject: [PATCH 54/78] adding randomized MCHMC --- blackjax/explore.py | 55 -------------- blackjax/mcmc/mclmc.py | 26 ++----- blackjax/mcmc/rmchmc.py | 155 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 161 insertions(+), 75 deletions(-) delete mode 100644 blackjax/explore.py create mode 100644 blackjax/mcmc/rmchmc.py diff --git a/blackjax/explore.py 
b/blackjax/explore.py deleted file mode 100644 index bb07779d4..000000000 --- a/blackjax/explore.py +++ /dev/null @@ -1,55 +0,0 @@ -import jax -import jax.numpy as jnp -import jax.scipy.stats as stats -import numpy as np -import sys -import blackjax -from blackjax.base import SamplingAlgorithm -import jax -import jax.numpy as jnp -import jax.scipy.stats as stats -import numpy as np -import sys -import blackjax -from blackjax.mcmc.mclmc import Parameters - -def logdensity_fn(x): - return -0.5*jnp.sum(jnp.square(x-5)) - -# Build the kernel -inverse_mass_matrix = jnp.array([1.0, 1.0]) - -# Initialize the state -initial_position = jnp.array([1.0, 1.0]) - - -mclmc = blackjax.mcmc.mclmc.mclmc( - logdensity_fn=logdensity_fn, - d=2, - transform=lambda x: x, - init_key=jax.random.PRNGKey(0), - params=Parameters(0.56568545, 1.4142135, inverse_mass_matrix), -) - -# ? -# tuning() - -flip = lambda f: lambda s, k: f(k, s) - -def run_sampling_algorithm( - sampling_algorithm: SamplingAlgorithm, num_steps: int, initial_val, rng_key -): - state = sampling_algorithm.init(initial_val) - keys = jax.random.split(rng_key, num_steps) - _, info = jax.lax.scan(flip(sampling_algorithm.step), state, keys) - return info - -out = run_sampling_algorithm( - sampling_algorithm=mclmc, - num_steps=10000, - initial_val=jnp.array([0.1, 0.1]), - rng_key=jax.random.PRNGKey(0), -) - -print(jnp.mean(out.transformed_x, axis=0)) - diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index cf3bd8c63..a3732f04e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -64,12 +64,12 @@ def init(x_initial : ArrayLikeTree, logdensity_fn, random_key): grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) l, g = grad_nlogp(x_initial) - u = random_unit_vector(random_key, d=x_initial.shape[0]) + u = full_refresh(random_key, d=x_initial.shape[0]) return MCLMCState(x_initial, u, l, g) -def random_unit_vector(random_key,d): +def full_refresh(random_key, d): u = jax.random.normal(jax.random.PRNGKey(0), shape = (d, )) u /= jnp.sqrt(jnp.sum(jnp.square(u))) return u @@ -102,33 +102,23 @@ def update(eps, u, g): return update -def partially_refresh_momentum(d, sequential= True): +def partial_refresh(d): """Adds a small noise to u and normalizes.""" - - def rng_sequential(u, random_key, nu): + def rng(u, random_key, nu): z = nu * jax.random.normal(random_key, shape = (d, )) return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))) - -# def rng_parallel(u, random_key, nu): -# key, subkey = jax.random.split(random_key) -# noise = nu * jax.random.normal(subkey, shape= u.shape, dtype=u.dtype) - -# return (u + noise) / jnp.sqrt(jnp.sum(jnp.square(u + noise), axis = 1))[:, None], key + return rng - return rng_sequential - def update(hamiltonian_dynamics, partially_refresh_momentum, d): -# print("BAR 4") def step(x, u, g, random_key, L, eps, sigma): """One step of the generalized dynamics.""" # Hamiltonian step - # print("BAR 3") xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, eps=eps, sigma = sigma) # Langevin-like noise @@ -146,15 +136,11 @@ def build_kernel(grad_nlogp, d, integrator, transform, params): hamiltonian_step, _ = integrator(T= update_position(grad_nlogp), V= update_momentum(d), d= d) - # print("BAR") - move = update(hamiltonian_step, partially_refresh_momentum(d), d) - # print("BAZ") + move = update(hamiltonian_step, partial_refresh(d), d) def kernel(rng_key : PRNGKey, state : MCLMCState) -> tuple[MCLMCState, MCLMCInfo]: x, u, l, g = state - - xx, uu, ll, gg, kinetic_change = move(x, u, g, rng_key, 
L, eps, sigma) de = kinetic_change + ll - l return MCLMCState(xx, uu, ll, gg), MCLMCInfo(transform(xx), ll, de) diff --git a/blackjax/mcmc/rmchmc.py b/blackjax/mcmc/rmchmc.py new file mode 100644 index 000000000..907c1d97d --- /dev/null +++ b/blackjax/mcmc/rmchmc.py @@ -0,0 +1,155 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the MCLMC Kernel""" +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +from blackjax.base import SamplingAlgorithm +from blackjax.types import Array, ArrayLikeTree, PRNGKey + +__all__ = ["RMCHMCState", "MCLMCInfo", "init", "build_kernel", "mclmc"] + + +from mclmc import Parameters, MCLMCInfo, full_refresh, update_position, update_momentum, minimal_norm + + +class RMCHMCState(NamedTuple): + """State of the MCLMC algorithm.""" + + t: float # time step (0., 1., 2., ....) + x: Array # location in the sampling space + l: float # - log p(x) + g: Array # - grad log p(x) + + +def init(x_initial : ArrayLikeTree, logdensity_fn): + + grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) + l, g = grad_nlogp(x_initial) + + return RMCHMCState(0., x_initial, l, g) + + + + +def halton(t, max_bits=10): + """for t= 0., 1., 2., ... it outputs halton sequence at that index (0.5, 0.25, 0.75, ...) + taken from: https://github.com/tensorflow/probability/blob/main/discussion/snaper_hmc/SNAPER-HMC.ipynb""" + float_index = jnp.asarray(t) + bit_masks = 2**jnp.arange(max_bits, dtype=float_index.dtype) + return jnp.einsum('i,i->', jnp.mod((float_index + 1) // bit_masks, 2), 0.5 / bit_masks) + + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. 
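+    For example, mu = 3 gives s = 5, so round(U(0, 1) * 5 + 0.5) is uniform
+    on {1, ..., 5} and indeed has mean 3.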
+ """ + k = jnp.floor(2 * mu -1) + x = k * (mu - 0.5 *(k+1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton(t) * s) + + + +def proposal(hamiltonian_step, d): + + def prop(t, x, g, random_key, L, eps, sigma): + + #jiter the number of steps + num_steps = jnp.rint(2 * halton(t) * L / eps).astype(int) + + #full momentum refreshment + u = full_refresh(random_key, d) + + # do num_steps of the Hamiltonian dynamics + + def body(i, state): + + x, u, l, g, kinetic_energy = state + xx, uu, ll, gg, kinetic_change = hamiltonian_step(x=x, u=u, g=g, eps=eps, sigma = sigma) + + return xx, uu, ll, gg, kinetic_energy + kinetic_change + + xx, uu, ll, gg, kinetic_change = jax.fori_loop(0, num_steps, body, (x, u, 0., g, 0.)) + + return xx, ll, gg, kinetic_change + + return prop + + +def build_kernel(grad_nlogp, d, integrator, transform, params): + + L, eps, sigma = params + + hamiltonian_step, _ = integrator(T= update_position(grad_nlogp), V= update_momentum(d), d= d) + get_proposal = proposal(hamiltonian_step, d) + + def kernel(rng_key : PRNGKey, state : RMCHMCState) -> tuple[RMCHMCState, MCLMCInfo]: + + key1, key2 = jax.random.split(rng_key) + + t, x, l, g = state + xx, ll, gg, kinetic_change = get_proposal(t, x, g, key1, L, eps, sigma) + de = kinetic_change + ll - l + + # accept/reject + + acc_prob = jnp.clip(jnp.exp(-de), 0, 1) + accept = jax.random.bernoulli(key2, acc_prob) + xx, ll, gg = jax.tree_util.tree_map(lambda new, old: jax.lax.select(accept, new, old), (xx, ll, gg), (x, l, g)) + + return RMCHMCState(t + 1., xx, ll, gg), MCLMCInfo(transform(xx), ll, de) + + return kernel + + +class rmchmc: + """todo: add documentation""" + + init = staticmethod(init) + build_kernel = staticmethod(build_kernel) + + def __new__( # type: ignore[misc] + cls, + logdensity_fn: Callable, + d : int, + transform : Callable, + params : Parameters, + *, + integrator = minimal_norm, + ) -> SamplingAlgorithm: + + grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) + + kernel = cls.build_kernel(grad_nlogp, d, integrator, transform, params) + + def init_fn(position: ArrayLikeTree): + return cls.init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + ) + + return SamplingAlgorithm(init_fn, step_fn) + From 2b323ce234c7f7ca3313a4ab1dc232267aebbd5e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 1 Dec 2023 18:10:15 +0100 Subject: [PATCH 55/78] wip checkpoint on tuning --- blackjax/adaptation/mclmc_adaptation.py | 427 ++++++++++++++++++++++++ blackjax/adaptation/step_size.py | 281 ---------------- blackjax/mcmc/integrators.py | 3 +- blackjax/mcmc/mclmc.py | 9 +- blackjax/util.py | 13 + 5 files changed, 450 insertions(+), 283 deletions(-) create mode 100644 blackjax/adaptation/mclmc_adaptation.py diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py new file mode 100644 index 000000000..4ef964a44 --- /dev/null +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -0,0 +1,427 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L. + +""" + +from typing import NamedTuple +import warnings + +from chex import PRNGKey +import jax +import jax.numpy as jnp +from scipy.fftpack import next_fast_len #type: ignore + +from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan +from blackjax.mcmc.mclmc import build_kernel, init + + +class MCLMCAdaptationState(NamedTuple): + """Tunable parameters for MCLMC""" + + L: float + step_size: float + + +def ess_corr(x): + """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html + shape(x) = (num_samples, d)""" + + input_array = jnp.array( + [ + x, + ] + ) + + num_chains = 1 # input_array.shape[0] + num_samples = input_array.shape[1] + + mean_across_chain = input_array.mean(axis=1, keepdims=True) + # Compute autocovariance estimates for every lag for the input array using FFT. + centered_array = input_array - mean_across_chain + m = next_fast_len(2 * num_samples) + ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) + ifft_ary *= jnp.conjugate(ifft_ary) + autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) + autocov_value = ( + jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples + ) + mean_autocov_var = autocov_value.mean(0, keepdims=True) + mean_var0 = ( + jnp.take(mean_autocov_var, jnp.array([0]), axis=1) + * num_samples + / (num_samples - 1.0) + ) + weighted_var = mean_var0 * (num_samples - 1.0) / num_samples + + weighted_var = jax.lax.cond( + num_chains > 1, + lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), + lambda _: weighted_var, + operand=None, + ) + + # Geyer's initial positive sequence + num_samples_even = num_samples - num_samples % 2 + mean_autocov_var_tp1 = jnp.take( + mean_autocov_var, jnp.arange(1, num_samples_even), axis=1 + ) + rho_hat = jnp.concatenate( + [ + jnp.ones_like(mean_var0), + 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, + ], + axis=1, + ) + + rho_hat = jnp.moveaxis(rho_hat, 1, 0) + rho_hat_even = rho_hat[0::2] + rho_hat_odd = rho_hat[1::2] + + mask0 = (rho_hat_even + rho_hat_odd) > 0.0 + carry_cond = jnp.ones_like(mask0[0]) + max_t = jnp.zeros_like(mask0[0], dtype=int) + + def positive_sequence_body_fn(state, mask_t): + t, carry_cond, max_t = state + next_mask = carry_cond & mask_t + next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) + return (t + 1, next_mask, next_max_t), next_mask + + (*_, max_t_next), mask = jax.lax.scan( + positive_sequence_body_fn, (0, carry_cond, max_t), mask0 + ) + indices = jnp.indices(max_t_next.shape) + indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) + rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) + # improve estimation + mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) + rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) + + # Geyer's initial monotone sequence + def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): + update_mask = rho_hat_sum_t > rho_hat_sum_tm1 + next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) + return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) + + rho_hat_sum = rho_hat_even + rho_hat_odd + _, (update_mask, update_value) = jax.lax.scan( + monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum + ) + + rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) 
+ rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) + + # compute effective sample size + ess_raw = num_chains * num_samples + tau_hat = ( + -1.0 + + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) + - rho_hat_even_final[indices] + ) + + tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) + ess = ess_raw / tau_hat + + neff = ess.squeeze() / num_samples + return 1.0 / jnp.average(1 / neff) + + +# def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): +# """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" + +# nonans = jnp.all(jnp.isfinite(xx)) +# return nonans, *jax.tree_util.tree_map( +# lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), +# (xx, uu, ll, gg, eps_max, dK), +# (x, u, l, g, eps * 0.8, 0.0), +# ) + + +# def dynamics_adaptive(dynamics, state, L): +# """One step of the dynamics with the adaptive stepsize""" + +# x, u, l, g, E, Feps, Weps, eps_max, key = state + +# eps = jnp.power( +# Feps / Weps, -1.0 / 6.0 +# ) # We use the Var[E] = O(eps^6) relation here. +# eps = (eps < eps_max) * eps + ( +# eps > eps_max +# ) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences + +# state, info = dynamics( +# jax.random.PRNGKey(0), IntegratorState(x, u, l, g), L=L, step_size=eps +# ) + +# xx, uu, ll, gg = state +# # ll, gg = -ll, -gg +# kinetic_change = info.kinetic_change + +# varEwanted = 5e-4 +# sigma_xi = 1.5 +# neff = 150 # effective number of steps used to determine the stepsize in the adaptive step +# gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step + +# # step updating +# success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject( +# x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change +# ) + +# DE = info.dE # energy difference +# EE = E + DE # energy +# # Warning: var = 0 if there were nans, but we will give it a very small weight +# xi = ( +# (DE**2) / (xx.shape[0] * varEwanted) +# ) + 1e-8 # 1e-8 is added to avoid divergences in log xi +# w = jnp.exp( +# -0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi)) +# ) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. 
+# Feps = gamma * Feps + w * ( +# xi / jnp.power(eps, 6.0) +# ) # Kalman update the linear combinations +# Weps = gamma * Weps + w + +# return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success + + +# def tune12(kernel, x, u, l, g, random_key, L, eps, num_steps1, num_steps2): +# """cheap hyperparameter tuning""" + +# def step(state, outer_weight): +# """one adaptive step of the dynamics""" +# x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive( +# kernel, state[0], L +# ) +# W, F1, F2 = state[1] +# w = outer_weight * eps +# zero_prevention = 1 - outer_weight +# F1 = (W * F1 + w * x) / ( +# W + w + zero_prevention +# ) # Update with a Kalman filter +# F2 = (W * F2 + w * jnp.square(x)) / ( +# W + w + zero_prevention +# ) # Update with a Kalman filter +# W += w + +# return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps + +# # we use the last num_steps2 to compute the diagonal preconditioner +# outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + +# # initial state +# state = ( +# (x, u, l, g, 0.0, jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), +# (0.0, jnp.zeros(len(x)), jnp.zeros(len(x))), +# ) +# # run the steps +# state, eps = jax.lax.scan( +# step, init=state, xs=outer_weights, length=num_steps1 + num_steps2 +# ) +# # determine L +# if num_steps2 != 0.0: +# F1, F2 = state[1][1], state[1][2] +# variances = F2 - jnp.square(F1) +# sigma2 = jnp.average(variances) + +# L = jnp.sqrt(sigma2 * x.shape[0]) + +# xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state +# return ( +# L, +# eps[-1], +# IntegratorState(xx, uu, ll, gg), +# ) # return the tuned hyperparameters and the final state + +# adapt_L_on_ess +def adapt_L_on_ess(kernel, state, rng_key, params, num_steps): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + + state, info = jax.lax.scan( + lambda s, k: (kernel(k, s, params.L, params.step_size)), state, jax.random.split(rng_key, num_steps) + ) + + Lfactor = 0.4 + # ESS2 = effective_sample_size(info.transformed_x) + # neff = ESS2.squeeze() / info.transformed_x.shape[0] + # ESS_alt = 1.0 / jnp.average(1 / neff) + ESS = ess_corr(info.transformed_x) + if ESS * num_steps <= 10: + warnings.warn( + "tune3 cannot be expected to work with 10 or fewer effective samples" + ) + + Lnew = Lfactor * params.step_size / ESS + return Lnew, state + +# def adapt_L_step_size(kernel, state, rng_key, params, num_steps1, num_steps2): + +def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): + """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" + + nonans = jnp.all(jnp.isfinite(xx)) + _x, _u, _l, _g, _eps, _dk = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (xx, uu, ll, gg, eps_max, dK), + (x, u, l, g, eps * 0.8, 0.)) + + return nonans, _x, _u, _l, _g, _eps, _dk + +def adapt_L_step_size(dynamics, d, frac, + varEwanted = 1e-3, sigma_xi = 1.5, neff = 150): + + print("Starting tune12 (blackjax)") + gamma_forget = (neff - 1.0) / (neff + 1.0) + + + def predictor(dyn_old, hyp, adaptive_state): + """does one step with the dynamics and updates the prediction for the optimal stepsize + Designed for the unadjusted MCHMC""" + + W, F, eps_max = adaptive_state + + # dynamics + dyn_new, energy_change = dynamics(dyn_old, hyp) + + # step updating + success, x, u, l, g, eps_max, energy_change = nan_reject(dyn_old.x, dyn_old.u, dyn_old.l, dyn_old.g, + dyn_new.x, dyn_new.u, dyn_new.l, dyn_new.g, + hyp.eps, eps_max, energy_change) + + dyn = State(x, u, l, g, dyn_new.key) + + # Warning: var = 0 if there were nans, but we will give it a very small weight + xi = (jnp.square(energy_change) / (d * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi + w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. + + F = gamma_forget * F + w * (xi/jnp.power(hyp.eps, 6.0)) + W = gamma_forget * W + w + eps = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. + eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences + hyp_new = Hyperparameters(hyp.L, eps, hyp.sigma) + + return dyn, hyp_new, hyp_new, (W, F, eps_max), success + + + def update_kalman(x, state, outer_weight, success, eps): + """kalman filter to estimate the size of the posterior""" + W, F1, F2 = state + w = outer_weight * eps * success + zero_prevention = 1-outer_weight + F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter + F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter + W += w + return (W, F1, F2) + + + adap0 = (0., 0., jnp.inf) + _step = predictor + + + def step(state, outer_weight): + """does one step of the dynamcis and updates the estimate of the posterior size and optimal stepsize""" + dyn, hyp, _, adaptive_state, kalman_state = state + dyn, hyp, hyp_final, adaptive_state, success = _step(dyn, hyp, adaptive_state) + kalman_state = update_kalman(dyn.x, kalman_state, outer_weight, success, hyp.eps) + + return (dyn, hyp, hyp_final, adaptive_state, kalman_state), None + + + def func(_dyn, _hyp, num_steps): + + num_steps1, num_steps2 = jnp.rint(num_steps * frac).astype(int) + + # we use the last num_steps2 to compute the diagonal preconditioner + outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + #initial state + + kalman_state = (0., jnp.zeros(d), jnp.zeros(d)) + + # run the steps + state = jax.lax.scan(step, init= (_dyn, _hyp, _hyp, adap0, kalman_state), xs= outer_weights, length= num_steps1 + num_steps2)[0] + dyn, _, hyp, adap, kalman_state = state + + L = hyp.L + sigma = hyp.sigma + # determine L + if num_steps2 != 0.: + _, F1, F2 = kalman_state + variances = F2 - jnp.square(F1) + L = jnp.sqrt(jnp.sum(variances)) + + # optionally we do the diagonal preconditioning (and readjust the stepsize) + if diag_precond: + + # diagonal preconditioning + sigma = jnp.sqrt(variances) + L = 
jnp.sqrt(d) + + #readjust the stepsize + steps = num_steps2 // 3 #we do some small number of steps + state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] + dyn, _, hyp, adap, kalman_state = state + else: + sigma = hyp.sigma + + jax.debug.print(" \n\n\nPARAMS:\n{x}", x=(dyn,Hyperparameters(L, hyp.eps, sigma) )) + return dyn, Hyperparameters(L, hyp.eps, sigma) + + return func + + +def mclmc_find_L_and_step_size( + position, + logdensity_fn, + num_steps: int, + rng_key: PRNGKey, + params: MCLMCAdaptationState, +) -> tuple[MCLMCAdaptationState, IntegratorState]: + num_tune_step_ratio_1 = 0.1 + num_tune_step_ratio_2 = 0.1 + num_tune_step_ratio_3 = 0.1 + + kernel = build_kernel( + logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x: x + ) + + init_key, tune1_key, tune2_key = jax.random.split(rng_key, 3) + + state = init(position, logdensity_fn=logdensity_fn, rng_key=init_key) + # x, u, l, g = ( + # jnp.array([0.1, 0.1]), + # jnp.array([-0.6755803, 0.73728645]), + # -0.010000001, + # -jnp.array([0.1, 0.1]), + # ) + + # jax.debug.print("eq {x}", x=(state.momentum,jnp.array([ 0.92340476, -0.38382764])) ) + # jax.debug.print("eq {x}", x=(state.position,jnp.array([ 1.,1.])) ) + + # position=state,IntegratorState(jnp.array([1., 1.]), jnp.array([ 0.92340476, -0.38382764]), jnp.array(1.), jnp.array([1., 1.]))) + + # MCLMCAdaptationState(L=jax.Array(1.41421356), eps=jax.Array(0.35355339, weak_type=True), sigma=jax.Array([1., 1.])) + + varEwanted = 5e-4 + d = state.position.shape[0] + + + params, state = adapt_L_step_size(kernel, d, jnp.array([num_tune_step_ratio_1, num_tune_step_ratio_2]), varEwanted, 1.5, 150)(state, params, num_steps) + + + L, step_size = params + + L, state = adapt_L_on_ess( + kernel, state, tune2_key, params, int(num_steps * num_tune_step_ratio_3) + ) + return MCLMCAdaptationState(L, step_size), state diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index b2d7326cc..6fbc7e924 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -262,284 +262,3 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: return rss_state.step_size - -class MCLMCAdaptationState(NamedTuple): - """Tunable parameters for MCLMC""" - - L: float - step_size: float - - -def ess_corr(x): - """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html - shape(x) = (num_samples, d)""" - - input_array = jnp.array( - [ - x, - ] - ) - - num_chains = 1 # input_array.shape[0] - num_samples = input_array.shape[1] - - mean_across_chain = input_array.mean(axis=1, keepdims=True) - # Compute autocovariance estimates for every lag for the input array using FFT. 
- centered_array = input_array - mean_across_chain - m = next_fast_len(2 * num_samples) - ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) - ifft_ary *= jnp.conjugate(ifft_ary) - autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) - autocov_value = ( - jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples - ) - mean_autocov_var = autocov_value.mean(0, keepdims=True) - mean_var0 = ( - jnp.take(mean_autocov_var, jnp.array([0]), axis=1) - * num_samples - / (num_samples - 1.0) - ) - weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - - weighted_var = jax.lax.cond( - num_chains > 1, - lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), - lambda _: weighted_var, - operand=None, - ) - - # Geyer's initial positive sequence - num_samples_even = num_samples - num_samples % 2 - mean_autocov_var_tp1 = jnp.take( - mean_autocov_var, jnp.arange(1, num_samples_even), axis=1 - ) - rho_hat = jnp.concatenate( - [ - jnp.ones_like(mean_var0), - 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, - ], - axis=1, - ) - - rho_hat = jnp.moveaxis(rho_hat, 1, 0) - rho_hat_even = rho_hat[0::2] - rho_hat_odd = rho_hat[1::2] - - mask0 = (rho_hat_even + rho_hat_odd) > 0.0 - carry_cond = jnp.ones_like(mask0[0]) - max_t = jnp.zeros_like(mask0[0], dtype=int) - - def positive_sequence_body_fn(state, mask_t): - t, carry_cond, max_t = state - next_mask = carry_cond & mask_t - next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) - return (t + 1, next_mask, next_max_t), next_mask - - (*_, max_t_next), mask = jax.lax.scan( - positive_sequence_body_fn, (0, carry_cond, max_t), mask0 - ) - indices = jnp.indices(max_t_next.shape) - indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) - rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) - # improve estimation - mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) - rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) - - # Geyer's initial monotone sequence - def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): - update_mask = rho_hat_sum_t > rho_hat_sum_tm1 - next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) - return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) - - rho_hat_sum = rho_hat_even + rho_hat_odd - _, (update_mask, update_value) = jax.lax.scan( - monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum - ) - - rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) - rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) - - # compute effective sample size - ess_raw = num_chains * num_samples - tau_hat = ( - -1.0 - + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) - - rho_hat_even_final[indices] - ) - - tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) - ess = ess_raw / tau_hat - - neff = ess.squeeze() / num_samples - return 1.0 / jnp.average(1 / neff) - - -def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): - """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" - - nonans = jnp.all(jnp.isfinite(xx)) - return nonans, *jax.tree_util.tree_map( - lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), - (xx, uu, ll, gg, eps_max, dK), - (x, u, l, g, eps * 0.8, 0.0), - ) - - -def dynamics_adaptive(dynamics, state, L): - """One step of the dynamics with the adaptive stepsize""" - - x, u, l, g, E, Feps, Weps, eps_max, key = state - - eps = jnp.power( - Feps / Weps, -1.0 / 6.0 - ) # We use the Var[E] = O(eps^6) relation here. - eps = (eps < eps_max) * eps + ( - eps > eps_max - ) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - - state, info = dynamics( - jax.random.PRNGKey(0), IntegratorState(x, u, l, g), L=L, step_size=eps - ) - - xx, uu, ll, gg = state - # ll, gg = -ll, -gg - kinetic_change = info.kinetic_change - - varEwanted = 5e-4 - sigma_xi = 1.5 - neff = 150 # effective number of steps used to determine the stepsize in the adaptive step - gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step - - # step updating - success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject( - x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change - ) - - DE = info.dE # energy difference - EE = E + DE # energy - # Warning: var = 0 if there were nans, but we will give it a very small weight - xi = ( - (DE**2) / (xx.shape[0] * varEwanted) - ) + 1e-8 # 1e-8 is added to avoid divergences in log xi - w = jnp.exp( - -0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi)) - ) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. - Feps = gamma * Feps + w * ( - xi / jnp.power(eps, 6.0) - ) # Kalman update the linear combinations - Weps = gamma * Weps + w - - return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success - - -def tune12(kernel, x, u, l, g, random_key, L, eps, num_steps1, num_steps2): - """cheap hyperparameter tuning""" - - def step(state, outer_weight): - """one adaptive step of the dynamics""" - x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive( - kernel, state[0], L - ) - W, F1, F2 = state[1] - w = outer_weight * eps - zero_prevention = 1 - outer_weight - F1 = (W * F1 + w * x) / ( - W + w + zero_prevention - ) # Update with a Kalman filter - F2 = (W * F2 + w * jnp.square(x)) / ( - W + w + zero_prevention - ) # Update with a Kalman filter - W += w - - return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps - - # we use the last num_steps2 to compute the diagonal preconditioner - outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - - # initial state - state = ( - (x, u, l, g, 0.0, jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), - (0.0, jnp.zeros(len(x)), jnp.zeros(len(x))), - ) - # run the steps - state, eps = jax.lax.scan( - step, init=state, xs=outer_weights, length=num_steps1 + num_steps2 - ) - # determine L - if num_steps2 != 0.0: - F1, F2 = state[1][1], state[1][2] - variances = F2 - jnp.square(F1) - sigma2 = jnp.average(variances) - - L = jnp.sqrt(sigma2 * x.shape[0]) - - xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state - return ( - L, - eps[-1], - IntegratorState(xx, uu, ll, gg), - ) # return the tuned hyperparameters and the final state - - -def tune3(kernel, state, rng_key, L, eps, num_steps): - """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - - state, info = jax.lax.scan( - lambda s, k: (kernel(k, s, L, eps)), state, 
jax.random.split(rng_key, num_steps) - ) - - Lfactor = 0.4 - # ESS2 = effective_sample_size(info.transformed_x) - # neff = ESS2.squeeze() / info.transformed_x.shape[0] - # ESS_alt = 1.0 / jnp.average(1 / neff) - ESS = ess_corr(info.transformed_x) - if ESS * num_steps <= 10: - warnings.warn( - "tune3 cannot be expected to work with 10 or fewer effective samples" - ) - - Lnew = Lfactor * eps / ESS - return Lnew, state - - -def tune( - position, - logdensity_fn, - num_steps: int, - rng_key: PRNGKey, - params: MCLMCAdaptationState, -) -> tuple[MCLMCAdaptationState, IntegratorState]: - num_tune_step_ratio_1 = 0.1 - num_tune_step_ratio_2 = 0.1 - - kernel = build_kernel( - logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x: x - ) - - init_key, tune1_key, tune2_key = jax.random.split(rng_key, 3) - - x, u, l, g = init(position, logdensity_fn=logdensity_fn, rng_key=init_key) - # x, u, l, g = ( - # jnp.array([0.1, 0.1]), - # jnp.array([-0.6755803, 0.73728645]), - # -0.010000001, - # -jnp.array([0.1, 0.1]), - # ) - - L, eps, state = tune12( - kernel, - x, - u, - l, - g, - tune1_key, - params.L, - params.step_size, - int(num_steps * num_tune_step_ratio_1), - int(num_steps * num_tune_step_ratio_1), - ) - - L, state = tune3( - kernel, state, tune2_key, L, eps, int(num_steps * num_tune_step_ratio_2) - ) - return MCLMCAdaptationState(L, eps), state diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index d9e9bff3a..0f89bd77e 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -366,5 +366,6 @@ def noneuclidean_integrator( noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients) -noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients) +noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) + diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 3700ee821..21b4a2f2d 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -19,7 +19,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey -from blackjax.util import generate_unit_vector, partially_refresh_momentum +from blackjax.util import full_refresh, generate_unit_vector, partially_refresh_momentum __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -47,6 +47,13 @@ class MCLMCInfo(NamedTuple): def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) + jax.debug.print("thing blackjax {x}", x=IntegratorState( + position=x_initial, + momentum=generate_unit_vector(rng_key, x_initial), + logdensity=l, + logdensity_grad=g, + )) + return IntegratorState( position=x_initial, momentum=generate_unit_vector(rng_key, x_initial), diff --git a/blackjax/util.py b/blackjax/util.py index 1c764137d..52c7c634c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -1,6 +1,7 @@ """Utility functions for BlackJax.""" from functools import partial from typing import Union +import jax import jax.numpy as jnp from jax import jit, lax @@ -103,6 +104,18 @@ def generate_unit_vector( sample = normal(rng_key, shape=p.shape, dtype=p.dtype) return unravel_fn(sample / jnp.linalg.norm(sample)) +def full_refresh(d): + """Generates a random (isotropic) unit vector.""" + + + def rng(random_key): + key, subkey = jax.random.split(random_key) + u = 
jax.random.normal(jax.random.PRNGKey(0), shape = (d, )) + u /= jnp.sqrt(jnp.sum(jnp.square(u))) + return u, key + + + return rng def partially_refresh_momentum(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. From 9a41cdfda6f840c283b3834b8c35e3e9f42e962b Mon Sep 17 00:00:00 2001 From: = Date: Fri, 1 Dec 2023 19:11:46 +0100 Subject: [PATCH 56/78] align blackjax and mclmc repos, for tuning --- blackjax/adaptation/mclmc_adaptation.py | 284 ++++++++---------------- 1 file changed, 97 insertions(+), 187 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 4ef964a44..30e5f3a3b 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -137,136 +137,55 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): neff = ess.squeeze() / num_samples return 1.0 / jnp.average(1 / neff) +import jax +import jax.numpy as jnp +from typing import NamedTuple -# def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): -# """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" - -# nonans = jnp.all(jnp.isfinite(xx)) -# return nonans, *jax.tree_util.tree_map( -# lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), -# (xx, uu, ll, gg, eps_max, dK), -# (x, u, l, g, eps * 0.8, 0.0), -# ) - - -# def dynamics_adaptive(dynamics, state, L): -# """One step of the dynamics with the adaptive stepsize""" - -# x, u, l, g, E, Feps, Weps, eps_max, key = state - -# eps = jnp.power( -# Feps / Weps, -1.0 / 6.0 -# ) # We use the Var[E] = O(eps^6) relation here. -# eps = (eps < eps_max) * eps + ( -# eps > eps_max -# ) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - -# state, info = dynamics( -# jax.random.PRNGKey(0), IntegratorState(x, u, l, g), L=L, step_size=eps -# ) - -# xx, uu, ll, gg = state -# # ll, gg = -ll, -gg -# kinetic_change = info.kinetic_change - -# varEwanted = 5e-4 -# sigma_xi = 1.5 -# neff = 150 # effective number of steps used to determine the stepsize in the adaptive step -# gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step - -# # step updating -# success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject( -# x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change -# ) - -# DE = info.dE # energy difference -# EE = E + DE # energy -# # Warning: var = 0 if there were nans, but we will give it a very small weight -# xi = ( -# (DE**2) / (xx.shape[0] * varEwanted) -# ) + 1e-8 # 1e-8 is added to avoid divergences in log xi -# w = jnp.exp( -# -0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi)) -# ) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one. 
-# Feps = gamma * Feps + w * ( -# xi / jnp.power(eps, 6.0) -# ) # Kalman update the linear combinations -# Weps = gamma * Weps + w - -# return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success - - -# def tune12(kernel, x, u, l, g, random_key, L, eps, num_steps1, num_steps2): -# """cheap hyperparameter tuning""" - -# def step(state, outer_weight): -# """one adaptive step of the dynamics""" -# x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive( -# kernel, state[0], L -# ) -# W, F1, F2 = state[1] -# w = outer_weight * eps -# zero_prevention = 1 - outer_weight -# F1 = (W * F1 + w * x) / ( -# W + w + zero_prevention -# ) # Update with a Kalman filter -# F2 = (W * F2 + w * jnp.square(x)) / ( -# W + w + zero_prevention -# ) # Update with a Kalman filter -# W += w - -# return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps - -# # we use the last num_steps2 to compute the diagonal preconditioner -# outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - -# # initial state -# state = ( -# (x, u, l, g, 0.0, jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key), -# (0.0, jnp.zeros(len(x)), jnp.zeros(len(x))), -# ) -# # run the steps -# state, eps = jax.lax.scan( -# step, init=state, xs=outer_weights, length=num_steps1 + num_steps2 -# ) -# # determine L -# if num_steps2 != 0.0: -# F1, F2 = state[1][1], state[1][2] -# variances = F2 - jnp.square(F1) -# sigma2 = jnp.average(variances) - -# L = jnp.sqrt(sigma2 * x.shape[0]) - -# xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state -# return ( -# L, -# eps[-1], -# IntegratorState(xx, uu, ll, gg), -# ) # return the tuned hyperparameters and the final state - -# adapt_L_on_ess -def adapt_L_on_ess(kernel, state, rng_key, params, num_steps): - """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - state, info = jax.lax.scan( - lambda s, k: (kernel(k, s, params.L, params.step_size)), state, jax.random.split(rng_key, num_steps) - ) - Lfactor = 0.4 - # ESS2 = effective_sample_size(info.transformed_x) - # neff = ESS2.squeeze() / info.transformed_x.shape[0] - # ESS_alt = 1.0 / jnp.average(1 / neff) - ESS = ess_corr(info.transformed_x) - if ESS * num_steps <= 10: - warnings.warn( - "tune3 cannot be expected to work with 10 or fewer effective samples" - ) - Lnew = Lfactor * params.step_size / ESS - return Lnew, state +def mclmc_find_L_and_step_size(kernel, num_steps, initial_state): + + d = initial_state.position.shape[0] + dyn = initial_state + hyp = MCLMCAdaptationState(jnp.sqrt(d), + jnp.sqrt(d) * 0.25, + ) + + frac_tune1 = 0.1 + frac_tune2 = 0.1 + frac_tune3 = 0.1 + varEwanted = 5e-4 + tune12p = tune12(kernel, d, False, jnp.array([frac_tune1, frac_tune2]), varEwanted, 1.5, 150) + + tune3p = tune3(kernel, frac_tune3, 0.4) + + if frac_tune3 != 0.: + tune3p = tune3(kernel, frac= frac_tune3, Lfactor= 0.4) + schedule = [tune12p, tune3p] + else: + schedule = [tune12p, ] + + dyn, hyp = run(dyn, hyp, schedule, num_steps) + return dyn, hyp + +# all tuning functions are wrappers, recieving some parameters and returning a function +# func(dyn, hyp, num_total_steps) -> (dyn, hyp) -# def adapt_L_step_size(kernel, state, rng_key, params, num_steps1, num_steps2): + +def run(dyn, hyp, schedule, num_steps): + + _dyn, _hyp = dyn, hyp + + for program in schedule: + _dyn, _hyp = program(_dyn, _hyp, num_steps) + + return _dyn, _hyp + + + + def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): """if there are nans, let's reduce the stepsize, and not 
update the state. The function returns the old state in this case.""" @@ -276,11 +195,16 @@ def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): (x, u, l, g, eps * 0.8, 0.)) return nonans, _x, _u, _l, _g, _eps, _dk + + -def adapt_L_step_size(dynamics, d, frac, - varEwanted = 1e-3, sigma_xi = 1.5, neff = 150): - print("Starting tune12 (blackjax)") +def tune12(kernel, d, + diag_precond, frac, + varEwanted = 1e-3, sigma_xi = 1.5, neff = 150): + + print("Starting tune12") + gamma_forget = (neff - 1.0) / (neff + 1.0) @@ -291,24 +215,25 @@ def predictor(dyn_old, hyp, adaptive_state): W, F, eps_max = adaptive_state # dynamics - dyn_new, energy_change = dynamics(dyn_old, hyp) - + # dyn_new, energy_change = dynamics(dyn_old, hyp) + dyn_new, info = kernel(rng_key = jax.random.PRNGKey(0), state=dyn_old, L=hyp.L, step_size=hyp.step_size) + energy_change = info.dE # step updating - success, x, u, l, g, eps_max, energy_change = nan_reject(dyn_old.x, dyn_old.u, dyn_old.l, dyn_old.g, - dyn_new.x, dyn_new.u, dyn_new.l, dyn_new.g, - hyp.eps, eps_max, energy_change) + success, x, u, l, g, eps_max, energy_change = nan_reject(dyn_old.position, dyn_old.momentum, dyn_old.logdensity, dyn_old.logdensity_grad, + dyn_new.position, dyn_new.momentum, dyn_new.logdensity, dyn_new.logdensity_grad, + hyp.step_size, eps_max, energy_change) - dyn = State(x, u, l, g, dyn_new.key) + dyn = IntegratorState(x, u, l, g) # Warning: var = 0 if there were nans, but we will give it a very small weight xi = (jnp.square(energy_change) / (d * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. - F = gamma_forget * F + w * (xi/jnp.power(hyp.eps, 6.0)) + F = gamma_forget * F + w * (xi/jnp.power(hyp.step_size, 6.0)) W = gamma_forget * W + w eps = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. 
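# (Explanatory note, not part of the original patch.) F/W is an exponentially
# forgotten running average of xi / eps^6, with xi ~ (dE)^2 / (d * varEwanted).
# Since the integrator's energy error scales as Var[E] = O(eps^6), xi / eps^6
# estimates a constant, and the step size that would bring xi to 1 (i.e. hit
# the desired energy variance) is eps = (F/W)^(-1/6), computed above.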
eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - hyp_new = Hyperparameters(hyp.L, eps, hyp.sigma) + hyp_new = MCLMCAdaptationState(hyp.L, eps) return dyn, hyp_new, hyp_new, (W, F, eps_max), success @@ -332,7 +257,7 @@ def step(state, outer_weight): """does one step of the dynamcis and updates the estimate of the posterior size and optimal stepsize""" dyn, hyp, _, adaptive_state, kalman_state = state dyn, hyp, hyp_final, adaptive_state, success = _step(dyn, hyp, adaptive_state) - kalman_state = update_kalman(dyn.x, kalman_state, outer_weight, success, hyp.eps) + kalman_state = update_kalman(dyn.position, kalman_state, outer_weight, success, hyp.step_size) return (dyn, hyp, hyp_final, adaptive_state, kalman_state), None @@ -353,75 +278,60 @@ def func(_dyn, _hyp, num_steps): dyn, _, hyp, adap, kalman_state = state L = hyp.L - sigma = hyp.sigma # determine L if num_steps2 != 0.: _, F1, F2 = kalman_state variances = F2 - jnp.square(F1) L = jnp.sqrt(jnp.sum(variances)) - # optionally we do the diagonal preconditioning (and readjust the stepsize) - if diag_precond: + # # optionally we do the diagonal preconditioning (and readjust the stepsize) + # if diag_precond: - # diagonal preconditioning - sigma = jnp.sqrt(variances) - L = jnp.sqrt(d) + # # diagonal preconditioning + # sigma = jnp.sqrt(variances) + # L = jnp.sqrt(d) - #readjust the stepsize - steps = num_steps2 // 3 #we do some small number of steps - state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] - dyn, _, hyp, adap, kalman_state = state - else: - sigma = hyp.sigma + # #readjust the stepsize + # steps = num_steps2 // 3 #we do some small number of steps + # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] + # dyn, _, hyp, adap, kalman_state = state + # else: + # sigma = hyp.sigma - jax.debug.print(" \n\n\nPARAMS:\n{x}", x=(dyn,Hyperparameters(L, hyp.eps, sigma) )) - return dyn, Hyperparameters(L, hyp.eps, sigma) + # jax.debug.print(" \n\n\nPARAMS:\n{x}", x=(dyn,MCLMCAdaptationState(L, hyp.step_size) )) + return dyn, MCLMCAdaptationState(L, hyp.step_size) return func -def mclmc_find_L_and_step_size( - position, - logdensity_fn, - num_steps: int, - rng_key: PRNGKey, - params: MCLMCAdaptationState, -) -> tuple[MCLMCAdaptationState, IntegratorState]: - num_tune_step_ratio_1 = 0.1 - num_tune_step_ratio_2 = 0.1 - num_tune_step_ratio_3 = 0.1 - kernel = build_kernel( - logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x: x - ) - init_key, tune1_key, tune2_key = jax.random.split(rng_key, 3) +def tune3(kernel, frac, Lfactor): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + - state = init(position, logdensity_fn=logdensity_fn, rng_key=init_key) - # x, u, l, g = ( - # jnp.array([0.1, 0.1]), - # jnp.array([-0.6755803, 0.73728645]), - # -0.010000001, - # -jnp.array([0.1, 0.1]), - # ) + def sample_full(num_steps, _dyn, hyp): + """Stores full x for each step. 
Used in tune2.""" - # jax.debug.print("eq {x}", x=(state.momentum,jnp.array([ 0.92340476, -0.38382764])) ) - # jax.debug.print("eq {x}", x=(state.position,jnp.array([ 1.,1.])) ) - - # position=state,IntegratorState(jnp.array([1., 1.]), jnp.array([ 0.92340476, -0.38382764]), jnp.array(1.), jnp.array([1., 1.]))) - - # MCLMCAdaptationState(L=jax.Array(1.41421356), eps=jax.Array(0.35355339, weak_type=True), sigma=jax.Array([1., 1.])) + def _step(state, useless): + dyn_old = state + # dyn_new, _ = step(dyn_old, hyp) + dyn_new, _ = kernel(rng_key=jax.random.PRNGKey(0), state=dyn_old, L=hyp.L, step_size=hyp.step_size) + + return dyn_new, dyn_new.position - varEwanted = 5e-4 - d = state.position.shape[0] + return jax.lax.scan(_step, init=_dyn, xs=None, length=num_steps) - - params, state = adapt_L_step_size(kernel, d, jnp.array([num_tune_step_ratio_1, num_tune_step_ratio_2]), varEwanted, 1.5, 150)(state, params, num_steps) - - L, step_size = params + def func(dyn, hyp, num_steps): + steps = jnp.rint(num_steps * frac).astype(int) + + dyn, X = sample_full(steps, dyn, hyp) + ESS = ess_corr(X) # num steps / effective sample size + Lnew = Lfactor * hyp.step_size / ESS # = 0.4 * length corresponding to one effective sample + + return dyn, MCLMCAdaptationState(Lnew, hyp.step_size) + + + return func - L, state = adapt_L_on_ess( - kernel, state, tune2_key, params, int(num_steps * num_tune_step_ratio_3) - ) - return MCLMCAdaptationState(L, step_size), state From cdbb4f6f6833f33f737fbbb0d273503ae2c73ce6 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 1 Dec 2023 20:16:59 +0100 Subject: [PATCH 57/78] use effective_sample_size --- blackjax/adaptation/mclmc_adaptation.py | 138 +++--------------------- 1 file changed, 17 insertions(+), 121 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 30e5f3a3b..a0a59c0f6 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -21,10 +21,14 @@ from chex import PRNGKey import jax import jax.numpy as jnp -from scipy.fftpack import next_fast_len #type: ignore +from scipy.fftpack import next_fast_len +from blackjax.diagnostics import effective_sample_size #type: ignore from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.mcmc.mclmc import build_kernel, init +import jax +import jax.numpy as jnp +from typing import NamedTuple class MCLMCAdaptationState(NamedTuple): @@ -33,130 +37,27 @@ class MCLMCAdaptationState(NamedTuple): L: float step_size: float - def ess_corr(x): - """Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html - shape(x) = (num_samples, d)""" - - input_array = jnp.array( - [ - x, - ] - ) - - num_chains = 1 # input_array.shape[0] - num_samples = input_array.shape[1] - - mean_across_chain = input_array.mean(axis=1, keepdims=True) - # Compute autocovariance estimates for every lag for the input array using FFT. 
- centered_array = input_array - mean_across_chain - m = next_fast_len(2 * num_samples) - ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1) - ifft_ary *= jnp.conjugate(ifft_ary) - autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1) - autocov_value = ( - jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples - ) - mean_autocov_var = autocov_value.mean(0, keepdims=True) - mean_var0 = ( - jnp.take(mean_autocov_var, jnp.array([0]), axis=1) - * num_samples - / (num_samples - 1.0) - ) - weighted_var = mean_var0 * (num_samples - 1.0) / num_samples - - weighted_var = jax.lax.cond( - num_chains > 1, - lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True), - lambda _: weighted_var, - operand=None, - ) - - # Geyer's initial positive sequence - num_samples_even = num_samples - num_samples % 2 - mean_autocov_var_tp1 = jnp.take( - mean_autocov_var, jnp.arange(1, num_samples_even), axis=1 - ) - rho_hat = jnp.concatenate( - [ - jnp.ones_like(mean_var0), - 1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var, - ], - axis=1, - ) - - rho_hat = jnp.moveaxis(rho_hat, 1, 0) - rho_hat_even = rho_hat[0::2] - rho_hat_odd = rho_hat[1::2] - - mask0 = (rho_hat_even + rho_hat_odd) > 0.0 - carry_cond = jnp.ones_like(mask0[0]) - max_t = jnp.zeros_like(mask0[0], dtype=int) - - def positive_sequence_body_fn(state, mask_t): - t, carry_cond, max_t = state - next_mask = carry_cond & mask_t - next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t) - return (t + 1, next_mask, next_max_t), next_mask - - (*_, max_t_next), mask = jax.lax.scan( - positive_sequence_body_fn, (0, carry_cond, max_t), mask0 - ) - indices = jnp.indices(max_t_next.shape) - indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)]) - rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd)) - # improve estimation - mask_even = mask.at[indices].set(rho_hat_even[indices] > 0) - rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even)) - - # Geyer's initial monotone sequence - def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): - update_mask = rho_hat_sum_t > rho_hat_sum_tm1 - next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t) - return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t) - - rho_hat_sum = rho_hat_even + rho_hat_odd - _, (update_mask, update_value) = jax.lax.scan( - monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum - ) - - rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even) - rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd) - - # compute effective sample size - ess_raw = num_chains * num_samples - tau_hat = ( - -1.0 - + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0) - - rho_hat_even_final[indices] - ) - - tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw)) - ess = ess_raw / tau_hat - - neff = ess.squeeze() / num_samples - return 1.0 / jnp.average(1 / neff) + num_samples = x.shape[0] + ess = 0.5 * effective_sample_size(jnp.array([x, x])) + ess_per_sample = ess / num_samples + return 1.0 / jnp.average(1 / ess_per_sample) -import jax -import jax.numpy as jnp -from typing import NamedTuple +def mclmc_find_L_and_step_size(kernel, num_steps, initial_state, frac_tune1 = 0.1, + frac_tune2 = 0.1, + frac_tune3 = 0.1): - -def mclmc_find_L_and_step_size(kernel, num_steps, initial_state): - - d = initial_state.position.shape[0] + dim = initial_state.position.shape[0] dyn = initial_state - hyp = MCLMCAdaptationState(jnp.sqrt(d), - 
jnp.sqrt(d) * 0.25, + hyp = MCLMCAdaptationState(jnp.sqrt(dim), + jnp.sqrt(dim) * 0.25, ) - frac_tune1 = 0.1 - frac_tune2 = 0.1 - frac_tune3 = 0.1 + varEwanted = 5e-4 - tune12p = tune12(kernel, d, False, jnp.array([frac_tune1, frac_tune2]), varEwanted, 1.5, 150) + tune12p = tune12(kernel, dim, False, jnp.array([frac_tune1, frac_tune2]), varEwanted, 1.5, 150) tune3p = tune3(kernel, frac_tune3, 0.4) @@ -169,11 +70,6 @@ def mclmc_find_L_and_step_size(kernel, num_steps, initial_state): dyn, hyp = run(dyn, hyp, schedule, num_steps) return dyn, hyp -# all tuning functions are wrappers, recieving some parameters and returning a function -# func(dyn, hyp, num_total_steps) -> (dyn, hyp) - - - def run(dyn, hyp, schedule, num_steps): _dyn, _hyp = dyn, hyp From 947d7176a745882a1bd2263400a51c4f663ae065 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 1 Dec 2023 20:25:09 +0100 Subject: [PATCH 58/78] patial rename --- blackjax/adaptation/mclmc_adaptation.py | 63 +++++++++++++++++-------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index a0a59c0f6..b2ae6d621 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -31,24 +31,56 @@ from typing import NamedTuple +from typing import NamedTuple + class MCLMCAdaptationState(NamedTuple): - """Tunable parameters for MCLMC""" + """Represents the tunable parameters for MCLMC adaptation. + Attributes: + L (float): The momentum decoherent rate for the MCLMC algorithm. + step_size (float): The step size used for the MCLMC algorithm. + """ + L: float step_size: float -def ess_corr(x): - num_samples = x.shape[0] - ess = 0.5 * effective_sample_size(jnp.array([x, x])) +def ess_corr(samples): + """ + Calculates the effective sample size correction for a given set of samples. + + Parameters: + x (ndarray): Array of samples. + + A light wrapper around the blackjax.diagnostics.effective_sample_size function. + + Returns: + float: The effective sample size correction. + """ + num_samples = samples.shape[0] + ess = 0.5 * effective_sample_size(jnp.array([samples, samples])) ess_per_sample = ess / num_samples return 1.0 / jnp.average(1 / ess_per_sample) -def mclmc_find_L_and_step_size(kernel, num_steps, initial_state, frac_tune1 = 0.1, - frac_tune2 = 0.1, - frac_tune3 = 0.1): +def mclmc_find_L_and_step_size(kernel, num_steps, initial_state, frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1): + """ + Finds the optimal value of L (step size) for the MCLMC algorithm. + Args: + kernel: The kernel function used for the MCMC algorithm. + num_steps: The number of MCMC steps that will subsequently be run, after tuning + initial_state: The initial state of the MCMC algorithm. + frac_tune1: The fraction of tuning for the first step of the adaptation. + frac_tune2: The fraction of tuning for the second step of the adaptation. + frac_tune3: The fraction of tuning for the third step of the adaptation. + + Returns: + dyn: The final state of the MCMC algorithm. + hyp: The final hyperparameters of the MCMC algorithm. + """ dim = initial_state.position.shape[0] dyn = initial_state hyp = MCLMCAdaptationState(jnp.sqrt(dim), @@ -62,26 +94,19 @@ def mclmc_find_L_and_step_size(kernel, num_steps, initial_state, frac_tune1 = 0. 
tune3p = tune3(kernel, frac_tune3, 0.4) if frac_tune3 != 0.: - tune3p = tune3(kernel, frac= frac_tune3, Lfactor= 0.4) + tune3p = tune3(kernel, frac=frac_tune3, Lfactor=0.4) schedule = [tune12p, tune3p] else: schedule = [tune12p, ] - dyn, hyp = run(dyn, hyp, schedule, num_steps) - return dyn, hyp - -def run(dyn, hyp, schedule, num_steps): - - _dyn, _hyp = dyn, hyp - for program in schedule: - _dyn, _hyp = program(_dyn, _hyp, num_steps) + dyn, hyp = program(dyn, hyp, num_steps) - return _dyn, _hyp - - + return dyn, hyp + + def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" From e9ab7b4e8ee70fa1704d359478b9d5324b3f0f3e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 1 Dec 2023 23:18:13 +0100 Subject: [PATCH 59/78] rename --- blackjax/adaptation/mclmc_adaptation.py | 185 +++++++++--------------- 1 file changed, 72 insertions(+), 113 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index b2ae6d621..1c362aefd 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -33,6 +33,8 @@ from typing import NamedTuple +from blackjax.util import pytree_size + class MCLMCAdaptationState(NamedTuple): """Represents the tunable parameters for MCLMC adaptation. @@ -44,26 +46,7 @@ class MCLMCAdaptationState(NamedTuple): L: float step_size: float -def ess_corr(samples): - """ - Calculates the effective sample size correction for a given set of samples. - - Parameters: - x (ndarray): Array of samples. - - A light wrapper around the blackjax.diagnostics.effective_sample_size function. - - Returns: - float: The effective sample size correction. - """ - num_samples = samples.shape[0] - ess = 0.5 * effective_sample_size(jnp.array([samples, samples])) - ess_per_sample = ess / num_samples - return 1.0 / jnp.average(1 / ess_per_sample) - - - -def mclmc_find_L_and_step_size(kernel, num_steps, initial_state, frac_tune1=0.1, +def mclmc_find_L_and_step_size(kernel, num_steps, state, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1): """ @@ -72,97 +55,74 @@ def mclmc_find_L_and_step_size(kernel, num_steps, initial_state, frac_tune1=0.1, Args: kernel: The kernel function used for the MCMC algorithm. num_steps: The number of MCMC steps that will subsequently be run, after tuning - initial_state: The initial state of the MCMC algorithm. + state: The initial state of the MCMC algorithm. frac_tune1: The fraction of tuning for the first step of the adaptation. frac_tune2: The fraction of tuning for the second step of the adaptation. frac_tune3: The fraction of tuning for the third step of the adaptation. Returns: - dyn: The final state of the MCMC algorithm. - hyp: The final hyperparameters of the MCMC algorithm. + state: The final state of the MCMC algorithm. + params: The final hyperparameters of the MCMC algorithm. 
""" - dim = initial_state.position.shape[0] - dyn = initial_state - hyp = MCLMCAdaptationState(jnp.sqrt(dim), + dim = state.position.shape[0] + params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, ) + varEwanted = 5e-4 - varEwanted = 5e-4 - tune12p = tune12(kernel, dim, False, jnp.array([frac_tune1, frac_tune2]), varEwanted, 1.5, 150) + state, params = make_L_step_size_adaptation(kernel, dim, jnp.array([frac_tune1, frac_tune2]), varEwanted, 1.5, 150)(state, params, num_steps) - tune3p = tune3(kernel, frac_tune3, 0.4) + if frac_tune3 != 0: - if frac_tune3 != 0.: - tune3p = tune3(kernel, frac=frac_tune3, Lfactor=0.4) - schedule = [tune12p, tune3p] - else: - schedule = [tune12p, ] + state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)(state,params, num_steps) + - for program in schedule: - dyn, hyp = program(dyn, hyp, num_steps) - return dyn, hyp + return state, params - -def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK): - """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" - - nonans = jnp.all(jnp.isfinite(xx)) - _x, _u, _l, _g, _eps, _dk = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), - (xx, uu, ll, gg, eps_max, dK), - (x, u, l, g, eps * 0.8, 0.)) - - return nonans, _x, _u, _l, _g, _eps, _dk - - - - -def tune12(kernel, d, - diag_precond, frac, +def make_L_step_size_adaptation(kernel, d, frac, varEwanted = 1e-3, sigma_xi = 1.5, neff = 150): - - print("Starting tune12") - + """Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC""" + gamma_forget = (neff - 1.0) / (neff + 1.0) - def predictor(dyn_old, hyp, adaptive_state): + def predictor(state_old, state, adaptive_state): """does one step with the dynamics and updates the prediction for the optimal stepsize Designed for the unadjusted MCHMC""" - W, F, eps_max = adaptive_state + W, F, step_size_max = adaptive_state - # dynamics - # dyn_new, energy_change = dynamics(dyn_old, hyp) - dyn_new, info = kernel(rng_key = jax.random.PRNGKey(0), state=dyn_old, L=hyp.L, step_size=hyp.step_size) + # stateamics + # state_new, energy_change = stateamics(state_old, state) + state_new, info = kernel(rng_key = jax.random.PRNGKey(0), state=state_old, L=state.L, step_size=state.step_size) energy_change = info.dE # step updating - success, x, u, l, g, eps_max, energy_change = nan_reject(dyn_old.position, dyn_old.momentum, dyn_old.logdensity, dyn_old.logdensity_grad, - dyn_new.position, dyn_new.momentum, dyn_new.logdensity, dyn_new.logdensity_grad, - hyp.step_size, eps_max, energy_change) + success, x, u, l, g, step_size_max, energy_change = handle_nans(state_old.position, state_old.momentum, state_old.logdensity, state_old.logdensity_grad, + state_new.position, state_new.momentum, state_new.logdensity, state_new.logdensity_grad, + state.step_size, step_size_max, energy_change) - dyn = IntegratorState(x, u, l, g) # Warning: var = 0 if there were nans, but we will give it a very small weight xi = (jnp.square(energy_change) / (d * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. 
- F = gamma_forget * F + w * (xi/jnp.power(hyp.step_size, 6.0)) + F = gamma_forget * F + w * (xi/jnp.power(state.step_size, 6.0)) W = gamma_forget * W + w - eps = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. - eps = (eps < eps_max) * eps + (eps > eps_max) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences - hyp_new = MCLMCAdaptationState(hyp.L, eps) + step_size = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. + step_size = (step_size < step_size_max) * step_size + (step_size > step_size_max) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences + state_new = MCLMCAdaptationState(state.L, step_size) - return dyn, hyp_new, hyp_new, (W, F, eps_max), success + return IntegratorState(x, u, l, g), state_new, state_new, (W, F, step_size_max), success - def update_kalman(x, state, outer_weight, success, eps): + def update_kalman(x, state, outer_weight, success, step_size): """kalman filter to estimate the size of the posterior""" W, F1, F2 = state - w = outer_weight * eps * success + w = outer_weight * step_size * success zero_prevention = 1-outer_weight F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter @@ -175,15 +135,15 @@ def update_kalman(x, state, outer_weight, success, eps): def step(state, outer_weight): - """does one step of the dynamcis and updates the estimate of the posterior size and optimal stepsize""" - dyn, hyp, _, adaptive_state, kalman_state = state - dyn, hyp, hyp_final, adaptive_state, success = _step(dyn, hyp, adaptive_state) - kalman_state = update_kalman(dyn.position, kalman_state, outer_weight, success, hyp.step_size) + """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + state, params, _, adaptive_state, kalman_state = state + state, params, params_final, adaptive_state, success = _step(state, params, adaptive_state) + kalman_state = update_kalman(state.position, kalman_state, outer_weight, success, params.step_size) - return (dyn, hyp, hyp_final, adaptive_state, kalman_state), None + return (state, params, params_final, adaptive_state, kalman_state), None - def func(_dyn, _hyp, num_steps): + def L_step_size_adaptation(state, params, num_steps): num_steps1, num_steps2 = jnp.rint(num_steps * frac).astype(int) @@ -195,64 +155,63 @@ def func(_dyn, _hyp, num_steps): kalman_state = (0., jnp.zeros(d), jnp.zeros(d)) # run the steps - state = jax.lax.scan(step, init= (_dyn, _hyp, _hyp, adap0, kalman_state), xs= outer_weights, length= num_steps1 + num_steps2)[0] - dyn, _, hyp, adap, kalman_state = state + kalman_state = jax.lax.scan(step, init= (state, params, params, adap0, kalman_state), xs= outer_weights, length= num_steps1 + num_steps2)[0] + state, _, params, _, kalman_state_output = kalman_state - L = hyp.L + L = params.L # determine L if num_steps2 != 0.: - _, F1, F2 = kalman_state + _, F1, F2 = kalman_state_output variances = F2 - jnp.square(F1) L = jnp.sqrt(jnp.sum(variances)) - # # optionally we do the diagonal preconditioning (and readjust the stepsize) - # if diag_precond: - # # diagonal preconditioning - # sigma = jnp.sqrt(variances) - # L = jnp.sqrt(d) - - # #readjust the stepsize - # steps = num_steps2 // 3 #we do some small number of steps - # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] - # dyn, _, hyp, adap, kalman_state = state - # else: - # 
sigma = hyp.sigma - - # jax.debug.print(" \n\n\nPARAMS:\n{x}", x=(dyn,MCLMCAdaptationState(L, hyp.step_size) )) - return dyn, MCLMCAdaptationState(L, hyp.step_size) + return state, MCLMCAdaptationState(L, params.step_size) - return func + return L_step_size_adaptation -def tune3(kernel, frac, Lfactor): +def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - def sample_full(num_steps, _dyn, hyp): - """Stores full x for each step. Used in tune2.""" + def sample_full(num_steps, state, params): - def _step(state, useless): - dyn_old = state - # dyn_new, _ = step(dyn_old, hyp) - dyn_new, _ = kernel(rng_key=jax.random.PRNGKey(0), state=dyn_old, L=hyp.L, step_size=hyp.step_size) + def step(state, _): + state, _ = kernel(rng_key=jax.random.PRNGKey(0), state=state, L=params.L, step_size=params.step_size) - return dyn_new, dyn_new.position + return state, state.position - return jax.lax.scan(_step, init=_dyn, xs=None, length=num_steps) + return jax.lax.scan(step, init=state, xs=None, length=num_steps) - def func(dyn, hyp, num_steps): + def adaptation_L(state, params, num_steps): steps = jnp.rint(num_steps * frac).astype(int) - dyn, X = sample_full(steps, dyn, hyp) - ESS = ess_corr(X) # num steps / effective sample size - Lnew = Lfactor * hyp.step_size / ESS # = 0.4 * length corresponding to one effective sample + state, samples = sample_full(steps, state, params) + num_samples = samples.shape[0] + ESS = 0.5 * effective_sample_size(jnp.array([samples, samples])) + ess_per_sample = ESS / num_samples + ESS = 1.0 / jnp.average(1 / ess_per_sample) - return dyn, MCLMCAdaptationState(Lnew, hyp.step_size) + Lnew = Lfactor * params.step_size / ESS # = 0.4 * length corresponding to one effective sample + return state, MCLMCAdaptationState(Lnew, params.step_size) - return func + return adaptation_L + + +def handle_nans(x, u, l, g, xx, uu, ll, gg, step_size, step_size_max, dK): + """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" + + nonans = jnp.all(jnp.isfinite(xx)) + x, u, l, g, step_size, dk = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (xx, uu, ll, gg, step_size_max, dK), + (x, u, l, g, step_size * 0.8, 0.)) + + return nonans, x, u, l, g, step_size, dk + + \ No newline at end of file From 72d70c6da10a64cc64c98aecebc3020bb751a76e Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 00:47:00 +0100 Subject: [PATCH 60/78] clean up tuning --- blackjax/adaptation/mclmc_adaptation.py | 133 ++++++++++-------------- blackjax/mcmc/mclmc.py | 5 +- blackjax/util.py | 13 --- 3 files changed, 59 insertions(+), 92 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 1c362aefd..b1eb9fd0c 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -16,25 +16,18 @@ """ from typing import NamedTuple -import warnings -from chex import PRNGKey import jax import jax.numpy as jnp -from scipy.fftpack import next_fast_len from blackjax.diagnostics import effective_sample_size #type: ignore - -from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan -from blackjax.mcmc.mclmc import build_kernel, init import jax import jax.numpy as jnp from typing import NamedTuple - - from typing import NamedTuple from blackjax.util import pytree_size + class MCLMCAdaptationState(NamedTuple): """Represents the tunable parameters for MCLMC adaptation. @@ -46,9 +39,9 @@ class MCLMCAdaptationState(NamedTuple): L: float step_size: float -def mclmc_find_L_and_step_size(kernel, num_steps, state, frac_tune1=0.1, +def mclmc_find_L_and_step_size(kernel, num_steps, state, part1_key, part2_key, frac_tune1=0.1, frac_tune2=0.1, - frac_tune3=0.1): + frac_tune3=0.1, ): """ Finds the optimal value of L (step size) for the MCLMC algorithm. @@ -64,66 +57,58 @@ def mclmc_find_L_and_step_size(kernel, num_steps, state, frac_tune1=0.1, state: The final state of the MCMC algorithm. params: The final hyperparameters of the MCMC algorithm. """ - dim = state.position.shape[0] - params = MCLMCAdaptationState(jnp.sqrt(dim), - jnp.sqrt(dim) * 0.25, - ) + dim = pytree_size(state.position) + params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) varEwanted = 5e-4 - - state, params = make_L_step_size_adaptation(kernel, dim, jnp.array([frac_tune1, frac_tune2]), varEwanted, 1.5, 150)(state, params, num_steps) + state, params = make_L_step_size_adaptation(kernel=kernel, dim=dim, frac_tune1=frac_tune1, frac_tune2=frac_tune2, varEwanted=varEwanted, sigma_xi=1.5, num_effective_samples=150)(state, params, num_steps, part1_key) if frac_tune3 != 0: - - state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)(state,params, num_steps) - - + state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)(state,params, num_steps, part2_key) return state, params -def make_L_step_size_adaptation(kernel, d, frac, - varEwanted = 1e-3, sigma_xi = 1.5, neff = 150): +def make_L_step_size_adaptation(kernel, dim, frac_tune1, frac_tune2, + varEwanted = 1e-3, sigma_xi = 1.5, num_effective_samples = 150): """Adapts the stepsize and L of the MCLMC kernel. 
Designed for the unadjusted MCLMC""" - gamma_forget = (neff - 1.0) / (neff + 1.0) + gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) - def predictor(state_old, state, adaptive_state): + def predictor(state_old, params, adaptive_state, rng_key): """does one step with the dynamics and updates the prediction for the optimal stepsize Designed for the unadjusted MCHMC""" W, F, step_size_max = adaptive_state - # stateamics - # state_new, energy_change = stateamics(state_old, state) - state_new, info = kernel(rng_key = jax.random.PRNGKey(0), state=state_old, L=state.L, step_size=state.step_size) + # dynamics + state_new, info = kernel(rng_key = rng_key, state=state_old, L=params.L, step_size=params.step_size) energy_change = info.dE # step updating - success, x, u, l, g, step_size_max, energy_change = handle_nans(state_old.position, state_old.momentum, state_old.logdensity, state_old.logdensity_grad, - state_new.position, state_new.momentum, state_new.logdensity, state_new.logdensity_grad, - state.step_size, step_size_max, energy_change) + success, state, step_size_max, energy_change = handle_nans(state_old,state_new, + params.step_size, step_size_max, energy_change) # Warning: var = 0 if there were nans, but we will give it a very small weight - xi = (jnp.square(energy_change) / (d * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi + xi = (jnp.square(energy_change) / (dim * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. - F = gamma_forget * F + w * (xi/jnp.power(state.step_size, 6.0)) + F = gamma_forget * F + w * (xi/jnp.power(params.step_size, 6.0)) W = gamma_forget * W + w step_size = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. 
step_size = (step_size < step_size_max) * step_size + (step_size > step_size_max) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences - state_new = MCLMCAdaptationState(state.L, step_size) + params_new = params._replace(step_size=step_size) - return IntegratorState(x, u, l, g), state_new, state_new, (W, F, step_size_max), success + return state, params_new, params_new, (W, F, step_size_max), success def update_kalman(x, state, outer_weight, success, step_size): """kalman filter to estimate the size of the posterior""" W, F1, F2 = state w = outer_weight * step_size * success - zero_prevention = 1-outer_weight + zero_prevention = 1 - outer_weight F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter W += w @@ -131,32 +116,36 @@ def update_kalman(x, state, outer_weight, success, step_size): adap0 = (0., 0., jnp.inf) - _step = predictor - def step(state, outer_weight): + def step(iteration_state, weight_and_key): + outer_weight, rng_key = weight_and_key """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" - state, params, _, adaptive_state, kalman_state = state - state, params, params_final, adaptive_state, success = _step(state, params, adaptive_state) + state, params, adaptive_state, kalman_state = iteration_state + state, params, params_final, adaptive_state, success = predictor(state, params, adaptive_state, rng_key) kalman_state = update_kalman(state.position, kalman_state, outer_weight, success, params.step_size) - return (state, params, params_final, adaptive_state, kalman_state), None + return (state, params_final, adaptive_state, kalman_state), None - def L_step_size_adaptation(state, params, num_steps): + def L_step_size_adaptation(state, params, num_steps, rng_key): - num_steps1, num_steps2 = jnp.rint(num_steps * frac).astype(int) + num_steps1, num_steps2 = int(num_steps * frac_tune1), int(num_steps*frac_tune2) + # TODO: change below to use jax.random.split + L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2) ) # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - #initial state - - kalman_state = (0., jnp.zeros(d), jnp.zeros(d)) + #initial state of the kalman filter + kalman_state = (0., jnp.zeros(dim), jnp.zeros(dim)) # run the steps - kalman_state = jax.lax.scan(step, init= (state, params, params, adap0, kalman_state), xs= outer_weights, length= num_steps1 + num_steps2)[0] - state, _, params, _, kalman_state_output = kalman_state + kalman_state = jax.lax.scan( + step, + init= (state, params, adap0, kalman_state), + xs=(outer_weights, L_step_size_adaptation_keys), length= num_steps1 + num_steps2)[0] + state, params, _, kalman_state_output = kalman_state L = params.L # determine L @@ -171,47 +160,37 @@ def L_step_size_adaptation(state, params, num_steps): return L_step_size_adaptation - - def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + def adaptation_L(state, params, num_steps, key): - def sample_full(num_steps, state, params): - - def step(state, _): - state, _ = kernel(rng_key=jax.random.PRNGKey(0), state=state, L=params.L, step_size=params.step_size) - - return state, state.position - - return jax.lax.scan(step, init=state, xs=None, length=num_steps) - 
- - def adaptation_L(state, params, num_steps): - steps = jnp.rint(num_steps * frac).astype(int) + num_steps = int(num_steps * frac) + # TODO: change below to use jax.random.split + adaptation_L_keys = jnp.array([key]*num_steps) - state, samples = sample_full(steps, state, params) - num_samples = samples.shape[0] - ESS = 0.5 * effective_sample_size(jnp.array([samples, samples])) - ess_per_sample = ESS / num_samples - ESS = 1.0 / jnp.average(1 / ess_per_sample) - - Lnew = Lfactor * params.step_size / ESS # = 0.4 * length corresponding to one effective sample - - return state, MCLMCAdaptationState(Lnew, params.step_size) + # run kernel in the normal way + state, info = jax.lax.scan( + f=lambda s, k: (kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)), + init=state, + xs=adaptation_L_keys) + samples = info.transformed_x # tranform is the identity here + ESS = 0.5 * effective_sample_size(jnp.array([samples, samples])) # TODO: should only use a single chain here + return state, params._replace(L=Lfactor * params.step_size * jnp.average(num_steps / ESS)) return adaptation_L -def handle_nans(x, u, l, g, xx, uu, ll, gg, step_size, step_size_max, dK): +def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" + + reduced_step_size = 0.8 + nonans = jnp.all(jnp.isfinite(state_new.position)) + state, step_size, kinetic_change = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (state_new, step_size_max, kinetic_change), + (state_old, step_size * reduced_step_size, 0.)) - nonans = jnp.all(jnp.isfinite(xx)) - x, u, l, g, step_size, dk = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), - (xx, uu, ll, gg, step_size_max, dK), - (x, u, l, g, step_size * 0.8, 0.)) - - return nonans, x, u, l, g, step_size, dk + return nonans, state, step_size, kinetic_change \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 21b4a2f2d..cf6114001 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -19,7 +19,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey -from blackjax.util import full_refresh, generate_unit_vector, partially_refresh_momentum +from blackjax.util import generate_unit_vector, partially_refresh_momentum __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -172,6 +172,7 @@ def __new__( # type: ignore[misc] L, step_size, integrator=noneuclidean_mclachlan, + seed=1, ) -> SamplingAlgorithm: kernel = cls.build_kernel(logdensity_fn, integrator, transform) @@ -179,6 +180,6 @@ def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) def init_fn(position: ArrayLike): - return cls.init(position, logdensity_fn, jax.random.PRNGKey(0)) + return cls.init(position, logdensity_fn, jax.random.PRNGKey(seed)) return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/util.py b/blackjax/util.py index 52c7c634c..7452d38ce 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -104,19 +104,6 @@ def generate_unit_vector( sample = normal(rng_key, shape=p.shape, dtype=p.dtype) return unravel_fn(sample / jnp.linalg.norm(sample)) -def full_refresh(d): - """Generates a random (isotropic) unit vector.""" - - - def rng(random_key): - key, subkey = jax.random.split(random_key) 
- u = jax.random.normal(jax.random.PRNGKey(0), shape = (d, )) - u /= jnp.sqrt(jnp.sum(jnp.square(u))) - return u, key - - - return rng - def partially_refresh_momentum(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. From c121bebbe4a3f0c203921fa0af2a46c51de46f29 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 00:58:23 +0100 Subject: [PATCH 61/78] clean up tuning --- blackjax/adaptation/mclmc_adaptation.py | 175 +++++++++++++++--------- blackjax/adaptation/step_size.py | 5 - blackjax/mcmc/integrators.py | 1 - blackjax/mcmc/mclmc.py | 15 +- blackjax/util.py | 2 +- 5 files changed, 118 insertions(+), 80 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index b1eb9fd0c..bd159cd3a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -19,12 +19,8 @@ import jax import jax.numpy as jnp -from blackjax.diagnostics import effective_sample_size #type: ignore -import jax -import jax.numpy as jnp -from typing import NamedTuple -from typing import NamedTuple +from blackjax.diagnostics import effective_sample_size # type: ignore from blackjax.util import pytree_size @@ -35,13 +31,21 @@ class MCLMCAdaptationState(NamedTuple): L (float): The momentum decoherent rate for the MCLMC algorithm. step_size (float): The step size used for the MCLMC algorithm. """ - + L: float step_size: float -def mclmc_find_L_and_step_size(kernel, num_steps, state, part1_key, part2_key, frac_tune1=0.1, + +def mclmc_find_L_and_step_size( + kernel, + num_steps, + state, + part1_key, + part2_key, + frac_tune1=0.1, frac_tune2=0.1, - frac_tune3=0.1, ): + frac_tune3=0.1, +): """ Finds the optimal value of L (step size) for the MCLMC algorithm. @@ -61,100 +65,131 @@ def mclmc_find_L_and_step_size(kernel, num_steps, state, part1_key, part2_key, f params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) varEwanted = 5e-4 - state, params = make_L_step_size_adaptation(kernel=kernel, dim=dim, frac_tune1=frac_tune1, frac_tune2=frac_tune2, varEwanted=varEwanted, sigma_xi=1.5, num_effective_samples=150)(state, params, num_steps, part1_key) + state, params = make_L_step_size_adaptation( + kernel=kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + varEwanted=varEwanted, + sigma_xi=1.5, + num_effective_samples=150, + )(state, params, num_steps, part1_key) if frac_tune3 != 0: - state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)(state,params, num_steps, part2_key) - - return state, params + state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)( + state, params, num_steps, part2_key + ) + return state, params -def make_L_step_size_adaptation(kernel, dim, frac_tune1, frac_tune2, - varEwanted = 1e-3, sigma_xi = 1.5, num_effective_samples = 150): +def make_L_step_size_adaptation( + kernel, + dim, + frac_tune1, + frac_tune2, + varEwanted=1e-3, + sigma_xi=1.5, + num_effective_samples=150, +): """Adapts the stepsize and L of the MCLMC kernel. 
Designed for the unadjusted MCLMC""" gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) - - + def predictor(state_old, params, adaptive_state, rng_key): """does one step with the dynamics and updates the prediction for the optimal stepsize - Designed for the unadjusted MCHMC""" - + Designed for the unadjusted MCHMC""" + W, F, step_size_max = adaptive_state # dynamics - state_new, info = kernel(rng_key = rng_key, state=state_old, L=params.L, step_size=params.step_size) + state_new, info = kernel( + rng_key=rng_key, state=state_old, L=params.L, step_size=params.step_size + ) energy_change = info.dE # step updating - success, state, step_size_max, energy_change = handle_nans(state_old,state_new, - params.step_size, step_size_max, energy_change) + success, state, step_size_max, energy_change = handle_nans( + state_old, state_new, params.step_size, step_size_max, energy_change + ) - # Warning: var = 0 if there were nans, but we will give it a very small weight - xi = (jnp.square(energy_change) / (dim * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi - w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. - - F = gamma_forget * F + w * (xi/jnp.power(params.step_size, 6.0)) + xi = ( + jnp.square(energy_change) / (dim * varEwanted) + ) + 1e-8 # 1e-8 is added to avoid divergences in log xi + w = jnp.exp( + -0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi)) + ) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. + + F = gamma_forget * F + w * (xi / jnp.power(params.step_size, 6.0)) W = gamma_forget * W + w - step_size = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here. - step_size = (step_size < step_size_max) * step_size + (step_size > step_size_max) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences + step_size = jnp.power( + F / W, -1.0 / 6.0 + ) # We use the Var[E] = O(eps^6) relation here. 
+ step_size = (step_size < step_size_max) * step_size + ( + step_size > step_size_max + ) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences params_new = params._replace(step_size=step_size) - - return state, params_new, params_new, (W, F, step_size_max), success + return state, params_new, params_new, (W, F, step_size_max), success def update_kalman(x, state, outer_weight, success, step_size): """kalman filter to estimate the size of the posterior""" W, F1, F2 = state w = outer_weight * step_size * success zero_prevention = 1 - outer_weight - F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update with a Kalman filter - F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update with a Kalman filter + F1 = (W * F1 + w * x) / ( + W + w + zero_prevention + ) # Update with a Kalman filter + F2 = (W * F2 + w * jnp.square(x)) / ( + W + w + zero_prevention + ) # Update with a Kalman filter W += w return (W, F1, F2) + adap0 = (0.0, 0.0, jnp.inf) - adap0 = (0., 0., jnp.inf) - - def step(iteration_state, weight_and_key): outer_weight, rng_key = weight_and_key """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" state, params, adaptive_state, kalman_state = iteration_state - state, params, params_final, adaptive_state, success = predictor(state, params, adaptive_state, rng_key) - kalman_state = update_kalman(state.position, kalman_state, outer_weight, success, params.step_size) + state, params, params_final, adaptive_state, success = predictor( + state, params, adaptive_state, rng_key + ) + kalman_state = update_kalman( + state.position, kalman_state, outer_weight, success, params.step_size + ) return (state, params_final, adaptive_state, kalman_state), None - def L_step_size_adaptation(state, params, num_steps, rng_key): - - num_steps1, num_steps2 = int(num_steps * frac_tune1), int(num_steps*frac_tune2) + num_steps1, num_steps2 = int(num_steps * frac_tune1), int( + num_steps * frac_tune2 + ) # TODO: change below to use jax.random.split - L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2) ) - + L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2)) + # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - #initial state of the kalman filter - kalman_state = (0., jnp.zeros(dim), jnp.zeros(dim)) + # initial state of the kalman filter + kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim)) # run the steps kalman_state = jax.lax.scan( - step, - init= (state, params, adap0, kalman_state), - xs=(outer_weights, L_step_size_adaptation_keys), length= num_steps1 + num_steps2)[0] + step, + init=(state, params, adap0, kalman_state), + xs=(outer_weights, L_step_size_adaptation_keys), + length=num_steps1 + num_steps2, + )[0] state, params, _, kalman_state_output = kalman_state - + L = params.L # determine L - if num_steps2 != 0.: + if num_steps2 != 0.0: _, F1, F2 = kalman_state_output variances = F2 - jnp.square(F1) L = jnp.sqrt(jnp.sum(variances)) - return state, MCLMCAdaptationState(L, params.step_size) return L_step_size_adaptation @@ -162,22 +197,28 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - - def adaptation_L(state, params, num_steps, key): + def adaptation_L(state, params, num_steps, key): 
num_steps = int(num_steps * frac) # TODO: change below to use jax.random.split - adaptation_L_keys = jnp.array([key]*num_steps) - + adaptation_L_keys = jnp.array([key] * num_steps) + # run kernel in the normal way state, info = jax.lax.scan( - f=lambda s, k: (kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)), - init=state, - xs=adaptation_L_keys) - samples = info.transformed_x # tranform is the identity here - ESS = 0.5 * effective_sample_size(jnp.array([samples, samples])) # TODO: should only use a single chain here - - return state, params._replace(L=Lfactor * params.step_size * jnp.average(num_steps / ESS)) + f=lambda s, k: ( + kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size) + ), + init=state, + xs=adaptation_L_keys, + ) + samples = info.transformed_x # tranform is the identity here + ESS = 0.5 * effective_sample_size( + jnp.array([samples, samples]) + ) # TODO: should only use a single chain here + + return state, params._replace( + L=Lfactor * params.step_size * jnp.average(num_steps / ESS) + ) return adaptation_L @@ -187,10 +228,10 @@ def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change): reduced_step_size = 0.8 nonans = jnp.all(jnp.isfinite(state_new.position)) - state, step_size, kinetic_change = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), - (state_new, step_size_max, kinetic_change), - (state_old, step_size * reduced_step_size, 0.)) - + state, step_size, kinetic_change = jax.tree_util.tree_map( + lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (state_new, step_size_max, kinetic_change), + (state_old, step_size * reduced_step_size, 0.0), + ) + return nonans, state, step_size, kinetic_change - - \ No newline at end of file diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 6fbc7e924..2d6b0182f 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Step size adaptation""" -import warnings from typing import Callable, NamedTuple import jax import jax.numpy as jnp -from scipy.fft import next_fast_len from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.integrators import noneuclidean_mclachlan -from blackjax.mcmc.mclmc import IntegratorState, build_kernel, init from blackjax.optimizers.dual_averaging import dual_averaging from blackjax.types import PRNGKey @@ -261,4 +257,3 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size - diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 0f89bd77e..c84502517 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -368,4 +368,3 @@ def noneuclidean_integrator( noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients) noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients) noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) - diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index cf6114001..f5bc75453 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -47,12 +47,15 @@ class MCLMCInfo(NamedTuple): def init(x_initial: ArrayLike, logdensity_fn, rng_key): l, g = jax.value_and_grad(logdensity_fn)(x_initial) - jax.debug.print("thing blackjax {x}", x=IntegratorState( - position=x_initial, - momentum=generate_unit_vector(rng_key, x_initial), - logdensity=l, - logdensity_grad=g, - )) + jax.debug.print( + "thing blackjax {x}", + x=IntegratorState( + position=x_initial, + momentum=generate_unit_vector(rng_key, x_initial), + logdensity=l, + logdensity_grad=g, + ), + ) return IntegratorState( position=x_initial, diff --git a/blackjax/util.py b/blackjax/util.py index 7452d38ce..1c764137d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -1,7 +1,6 @@ """Utility functions for BlackJax.""" from functools import partial from typing import Union -import jax import jax.numpy as jnp from jax import jit, lax @@ -104,6 +103,7 @@ def generate_unit_vector( sample = normal(rng_key, shape=p.shape, dtype=p.dtype) return unravel_fn(sample / jnp.linalg.norm(sample)) + def partially_refresh_momentum(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. 
From c456efe69742d413ab937c8ec3579d122c13ddf8 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 01:23:14 +0100 Subject: [PATCH 62/78] RANDOMIZE KEYS --- blackjax/adaptation/mclmc_adaptation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 7e5a6586e..63f73cb0e 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -165,8 +165,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps1, num_steps2 = int(num_steps * frac_tune1), int( num_steps * frac_tune2 ) - # TODO: change below to use jax.random.split - L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2)) + L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) @@ -200,8 +199,7 @@ def make_adaptation_L(kernel, frac, Lfactor): def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) - # TODO: change below to use jax.random.split - adaptation_L_keys = jnp.array([key] * num_steps) + adaptation_L_keys = jax.random.split(key, num_steps) # run kernel in the normal way state, info = jax.lax.scan( From d0a008ab6fec75ff558106fa36f7731b234ba459 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 02:03:28 +0100 Subject: [PATCH 63/78] ADD TEST --- blackjax/adaptation/__init__.py | 2 + blackjax/adaptation/mclmc_adaptation.py | 17 +++++-- tests/mcmc/test_sampling.py | 64 ++++++++++++------------- 3 files changed, 44 insertions(+), 39 deletions(-) diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py index 91a491ed0..0b89c3793 100644 --- a/blackjax/adaptation/__init__.py +++ b/blackjax/adaptation/__init__.py @@ -3,6 +3,7 @@ meads_adaptation, pathfinder_adaptation, window_adaptation, + mclmc_adaptation, ) __all__ = [ @@ -10,4 +11,5 @@ "meads_adaptation", "window_adaptation", "pathfinder_adaptation", + "mclmc_adaptation" ] diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 63f73cb0e..5395e40c2 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -22,6 +22,8 @@ from blackjax.diagnostics import effective_sample_size # type: ignore from blackjax.util import pytree_size +from jax.flatten_util import ravel_pytree + class MCLMCAdaptationState(NamedTuple): @@ -155,8 +157,9 @@ def step(iteration_state, weight_and_key): state, params, params_final, adaptive_state, success = predictor( state, params, adaptive_state, rng_key ) + position, _ = ravel_pytree(state.position) kalman_state = update_kalman( - state.position, kalman_state, outer_weight, success, params.step_size + position, kalman_state, outer_weight, success, params.step_size ) return (state, params_final, adaptive_state, kalman_state), None @@ -165,7 +168,8 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps1, num_steps2 = int(num_steps * frac_tune1), int( num_steps * frac_tune2 ) - L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) + # L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) + L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2)) # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), 
jnp.ones(num_steps2))) @@ -199,7 +203,8 @@ def make_adaptation_L(kernel, frac, Lfactor): def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) - adaptation_L_keys = jax.random.split(key, num_steps) + # adaptation_L_keys = jax.random.split(key, num_steps) + adaptation_L_keys = jnp.array([key] * (num_steps)) # run kernel in the normal way state, info = jax.lax.scan( @@ -210,8 +215,9 @@ def adaptation_L(state, params, num_steps, key): xs=adaptation_L_keys, ) samples = info.transformed_position # tranform is the identity here + flat_samples, unravel_fn = ravel_pytree(samples) ESS = 0.5 * effective_sample_size( - jnp.array([samples, samples]) + jnp.array([flat_samples, flat_samples]) ) # TODO: should only use a single chain here return state, params._replace( @@ -225,7 +231,8 @@ def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" reduced_step_size = 0.8 - nonans = jnp.all(jnp.isfinite(state_new.position)) + p, unravel_fn = ravel_pytree(state_new.position) + nonans = (jnp.all(jnp.isfinite(p))) state, step_size, kinetic_change = jax.tree_util.tree_map( lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (state_new, step_size_max, kinetic_change), diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 801fea6fd..84c7e589d 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -84,35 +84,31 @@ def regression_logprob(self, log_scale, coefs, preds, x): # reduce sum otherwise broacasting will make the logprob biased. return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) - # def tune_and_run(position, logdensity_fn, key, dim, num_steps): - # main_key, tune_key = jax.random.split(key) - - # # params, state = tune( - # # position=position, - # # params=MCLMCAdaptationState(L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4), - # # logdensity_fn=logdensity_fn, - # # num_steps=num_steps, - # # rng_key=tune_key, - # # ) - # # print( - # # f"L is {params.L} and should be {1.3147894144058228} and step_size is {params.step_size} and should be {0.6470216512680054}" - # # ) - - # mclmc = blackjax.mcmc.mclmc.mclmc( - # logdensity_fn=logdensity_fn, - # transform=lambda x: x, - # # L=params.L, - # # step_size=params.step_size, - # L=math.sqrt(dim), step_size=math.sqrt(dim) * 0.4 - # ) - - # return run_sampling_algorithm( - # sampling_algorithm=mclmc, - # num_steps=num_steps, - # # initial_val=state.position, - # initial_val=position, - # rng_key=main_key, - # ) + def run_mclmc(self, logdensity_fn,num_steps, initial_position, key): + + init_key, part1_key, part2_key, run_key = jax.random.split(key, 4) + + initial_state = blackjax.mcmc.mclmc.init(x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=key) + + kernel = blackjax.mcmc.mclmc.build_kernel(logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan, transform=lambda x: x) + + blackjax_state_after_tuning, blackjax_mclmc_sampler_params = blackjax.adaptation.mclmc_adaptation.mclmc_find_L_and_step_size( + kernel=kernel, + num_steps=num_steps, + state=initial_state, + part1_key=key, + part2_key=key, + ) + + keys = jax.random.split(key, num_steps) + + + _, blackjax_mclmc_result = jax.lax.scan( + f=lambda state, key: kernel(L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, rng_key=key, state=state), + xs=keys, + 
init=blackjax_state_after_tuning) + + return blackjax_mclmc_result.transformed_position @parameterized.parameters(itertools.product(regression_test_cases, [True, False])) def test_window_adaptation(self, case, is_mass_matrix_diagonal): @@ -184,13 +180,13 @@ def test_mclmc(self): ) logdensity_fn = lambda x: logposterior_fn_(**x) - states = tune_and_run(position={"coefs": 1.0, "log_scale": 1.0}, logdensity_fn=logdensity_fn, key=inference_key, dim=2, num_steps=10000) + states = self.run_mclmc(initial_position={"coefs": 1.0, "log_scale": 1.0}, logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000) - coefs_samples = states.transformed_position["coefs"][3000:] - scale_samples = np.exp(states.transformed_position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( From d692498aaddb98d56e7b1b8606c2603f77992422 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 02:28:09 +0100 Subject: [PATCH 64/78] ADD TEST --- blackjax/adaptation/__init__.py | 4 +- blackjax/adaptation/mclmc_adaptation.py | 9 +- blackjax/mcmc/rmchmc.py | 104 +++++++++++++----------- tests/mcmc/test_sampling.py | 42 +++++++--- 4 files changed, 92 insertions(+), 67 deletions(-) diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py index 0b89c3793..53d5fe2b6 100644 --- a/blackjax/adaptation/__init__.py +++ b/blackjax/adaptation/__init__.py @@ -1,9 +1,9 @@ from . 
import ( chees_adaptation, + mclmc_adaptation, meads_adaptation, pathfinder_adaptation, window_adaptation, - mclmc_adaptation, ) __all__ = [ @@ -11,5 +11,5 @@ "meads_adaptation", "window_adaptation", "pathfinder_adaptation", - "mclmc_adaptation" + "mclmc_adaptation", ] diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 5395e40c2..e97992a1a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -19,11 +19,10 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size # type: ignore from blackjax.util import pytree_size -from jax.flatten_util import ravel_pytree - class MCLMCAdaptationState(NamedTuple): @@ -215,7 +214,9 @@ def adaptation_L(state, params, num_steps, key): xs=adaptation_L_keys, ) samples = info.transformed_position # tranform is the identity here - flat_samples, unravel_fn = ravel_pytree(samples) + flat_samples, _ = ravel_pytree(samples) + dim = pytree_size(state.position) + flat_samples = flat_samples.reshape(-1, dim) ESS = 0.5 * effective_sample_size( jnp.array([flat_samples, flat_samples]) ) # TODO: should only use a single chain here @@ -232,7 +233,7 @@ def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change): reduced_step_size = 0.8 p, unravel_fn = ravel_pytree(state_new.position) - nonans = (jnp.all(jnp.isfinite(p))) + nonans = jnp.all(jnp.isfinite(p)) state, step_size, kinetic_change = jax.tree_util.tree_map( lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (state_new, step_size_max, kinetic_change), diff --git a/blackjax/mcmc/rmchmc.py b/blackjax/mcmc/rmchmc.py index 907c1d97d..e9cf54b4c 100644 --- a/blackjax/mcmc/rmchmc.py +++ b/blackjax/mcmc/rmchmc.py @@ -23,101 +23,109 @@ __all__ = ["RMCHMCState", "MCLMCInfo", "init", "build_kernel", "mclmc"] -from mclmc import Parameters, MCLMCInfo, full_refresh, update_position, update_momentum, minimal_norm +from mclmc import ( + MCLMCInfo, + Parameters, + full_refresh, + minimal_norm, + update_momentum, + update_position, +) class RMCHMCState(NamedTuple): """State of the MCLMC algorithm.""" - - t: float # time step (0., 1., 2., ....) - x: Array # location in the sampling space - l: float # - log p(x) - g: Array # - grad log p(x) + t: float # time step (0., 1., 2., ....) + x: Array # location in the sampling space + l: float # - log p(x) + g: Array # - grad log p(x) -def init(x_initial : ArrayLikeTree, logdensity_fn): - grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) +def init(x_initial: ArrayLikeTree, logdensity_fn): + grad_nlogp = jax.value_and_grad(lambda x: -logdensity_fn(x)) l, g = grad_nlogp(x_initial) - return RMCHMCState(0., x_initial, l, g) - - + return RMCHMCState(0.0, x_initial, l, g) def halton(t, max_bits=10): """for t= 0., 1., 2., ... it outputs halton sequence at that index (0.5, 0.25, 0.75, ...) 
- taken from: https://github.com/tensorflow/probability/blob/main/discussion/snaper_hmc/SNAPER-HMC.ipynb""" + taken from: https://github.com/tensorflow/probability/blob/main/discussion/snaper_hmc/SNAPER-HMC.ipynb + """ float_index = jnp.asarray(t) - bit_masks = 2**jnp.arange(max_bits, dtype=float_index.dtype) - return jnp.einsum('i,i->', jnp.mod((float_index + 1) // bit_masks, 2), 0.5 / bit_masks) - + bit_masks = 2 ** jnp.arange(max_bits, dtype=float_index.dtype) + return jnp.einsum( + "i,i->", jnp.mod((float_index + 1) // bit_masks, 2), 0.5 / bit_masks + ) def rescale(mu): - """returns s, such that - round(U(0, 1) * s + 0.5) - has expected value mu. + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. """ - k = jnp.floor(2 * mu -1) - x = k * (mu - 0.5 *(k+1)) / (k + 1 - mu) + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) return k + x - + def trajectory_length(t, mu): s = rescale(mu) return jnp.rint(0.5 + halton(t) * s) - def proposal(hamiltonian_step, d): - def prop(t, x, g, random_key, L, eps, sigma): - - #jiter the number of steps + # jiter the number of steps num_steps = jnp.rint(2 * halton(t) * L / eps).astype(int) - - #full momentum refreshment + + # full momentum refreshment u = full_refresh(random_key, d) # do num_steps of the Hamiltonian dynamics def body(i, state): - x, u, l, g, kinetic_energy = state - xx, uu, ll, gg, kinetic_change = hamiltonian_step(x=x, u=u, g=g, eps=eps, sigma = sigma) + xx, uu, ll, gg, kinetic_change = hamiltonian_step( + x=x, u=u, g=g, eps=eps, sigma=sigma + ) return xx, uu, ll, gg, kinetic_energy + kinetic_change - - xx, uu, ll, gg, kinetic_change = jax.fori_loop(0, num_steps, body, (x, u, 0., g, 0.)) - + + xx, uu, ll, gg, kinetic_change = jax.fori_loop( + 0, num_steps, body, (x, u, 0.0, g, 0.0) + ) + return xx, ll, gg, kinetic_change return prop def build_kernel(grad_nlogp, d, integrator, transform, params): - L, eps, sigma = params - hamiltonian_step, _ = integrator(T= update_position(grad_nlogp), V= update_momentum(d), d= d) + hamiltonian_step, _ = integrator( + T=update_position(grad_nlogp), V=update_momentum(d), d=d + ) get_proposal = proposal(hamiltonian_step, d) - - def kernel(rng_key : PRNGKey, state : RMCHMCState) -> tuple[RMCHMCState, MCLMCInfo]: - + + def kernel(rng_key: PRNGKey, state: RMCHMCState) -> tuple[RMCHMCState, MCLMCInfo]: key1, key2 = jax.random.split(rng_key) - + t, x, l, g = state xx, ll, gg, kinetic_change = get_proposal(t, x, g, key1, L, eps, sigma) de = kinetic_change + ll - l - + # accept/reject acc_prob = jnp.clip(jnp.exp(-de), 0, 1) accept = jax.random.bernoulli(key2, acc_prob) - xx, ll, gg = jax.tree_util.tree_map(lambda new, old: jax.lax.select(accept, new, old), (xx, ll, gg), (x, l, g)) - - return RMCHMCState(t + 1., xx, ll, gg), MCLMCInfo(transform(xx), ll, de) + xx, ll, gg = jax.tree_util.tree_map( + lambda new, old: jax.lax.select(accept, new, old), (xx, ll, gg), (x, l, g) + ) + + return RMCHMCState(t + 1.0, xx, ll, gg), MCLMCInfo(transform(xx), ll, de) return kernel @@ -131,14 +139,13 @@ class rmchmc: def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, - d : int, - transform : Callable, - params : Parameters, + d: int, + transform: Callable, + params: Parameters, *, - integrator = minimal_norm, + integrator=minimal_norm, ) -> SamplingAlgorithm: - - grad_nlogp = jax.value_and_grad(lambda x : - logdensity_fn(x)) + grad_nlogp = jax.value_and_grad(lambda x: -logdensity_fn(x)) kernel = cls.build_kernel(grad_nlogp, d, integrator, transform, params) @@ 
-152,4 +159,3 @@ def step_fn(rng_key: PRNGKey, state): ) return SamplingAlgorithm(init_fn, step_fn) - diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 84c7e589d..58dbea645 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -84,15 +84,23 @@ def regression_logprob(self, log_scale, coefs, preds, x): # reduce sum otherwise broacasting will make the logprob biased. return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) - def run_mclmc(self, logdensity_fn,num_steps, initial_position, key): - + def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): init_key, part1_key, part2_key, run_key = jax.random.split(key, 4) - - initial_state = blackjax.mcmc.mclmc.init(x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=key) - - kernel = blackjax.mcmc.mclmc.build_kernel(logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan, transform=lambda x: x) - blackjax_state_after_tuning, blackjax_mclmc_sampler_params = blackjax.adaptation.mclmc_adaptation.mclmc_find_L_and_step_size( + initial_state = blackjax.mcmc.mclmc.init( + x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=key + ) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan, + transform=lambda x: x, + ) + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.adaptation.mclmc_adaptation.mclmc_find_L_and_step_size( kernel=kernel, num_steps=num_steps, state=initial_state, @@ -102,11 +110,16 @@ def run_mclmc(self, logdensity_fn,num_steps, initial_position, key): keys = jax.random.split(key, num_steps) - _, blackjax_mclmc_result = jax.lax.scan( - f=lambda state, key: kernel(L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, rng_key=key, state=state), - xs=keys, - init=blackjax_state_after_tuning) + f=lambda state, key: kernel( + L=blackjax_mclmc_sampler_params.L, + step_size=blackjax_mclmc_sampler_params.step_size, + rng_key=key, + state=state, + ), + xs=keys, + init=blackjax_state_after_tuning, + ) return blackjax_mclmc_result.transformed_position @@ -180,7 +193,12 @@ def test_mclmc(self): ) logdensity_fn = lambda x: logposterior_fn_(**x) - states = self.run_mclmc(initial_position={"coefs": 1.0, "log_scale": 1.0}, logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000) + states = self.run_mclmc( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + ) coefs_samples = states["coefs"][3000:] scale_samples = np.exp(states["log_scale"][3000:]) From a45f58fac9192417895abac649167c2c6fa8f2bc Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 02:30:51 +0100 Subject: [PATCH 65/78] MERGE MAIN --- blackjax/mcmc/rmchmc.py | 161 ---------------------------------------- 1 file changed, 161 deletions(-) delete mode 100644 blackjax/mcmc/rmchmc.py diff --git a/blackjax/mcmc/rmchmc.py b/blackjax/mcmc/rmchmc.py deleted file mode 100644 index e9cf54b4c..000000000 --- a/blackjax/mcmc/rmchmc.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright 2020- The Blackjax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Public API for the MCLMC Kernel""" -from typing import Callable, NamedTuple - -import jax -import jax.numpy as jnp - -from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, PRNGKey - -__all__ = ["RMCHMCState", "MCLMCInfo", "init", "build_kernel", "mclmc"] - - -from mclmc import ( - MCLMCInfo, - Parameters, - full_refresh, - minimal_norm, - update_momentum, - update_position, -) - - -class RMCHMCState(NamedTuple): - """State of the MCLMC algorithm.""" - - t: float # time step (0., 1., 2., ....) - x: Array # location in the sampling space - l: float # - log p(x) - g: Array # - grad log p(x) - - -def init(x_initial: ArrayLikeTree, logdensity_fn): - grad_nlogp = jax.value_and_grad(lambda x: -logdensity_fn(x)) - l, g = grad_nlogp(x_initial) - - return RMCHMCState(0.0, x_initial, l, g) - - -def halton(t, max_bits=10): - """for t= 0., 1., 2., ... it outputs halton sequence at that index (0.5, 0.25, 0.75, ...) - taken from: https://github.com/tensorflow/probability/blob/main/discussion/snaper_hmc/SNAPER-HMC.ipynb - """ - float_index = jnp.asarray(t) - bit_masks = 2 ** jnp.arange(max_bits, dtype=float_index.dtype) - return jnp.einsum( - "i,i->", jnp.mod((float_index + 1) // bit_masks, 2), 0.5 / bit_masks - ) - - -def rescale(mu): - """returns s, such that - round(U(0, 1) * s + 0.5) - has expected value mu. 
- """ - k = jnp.floor(2 * mu - 1) - x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) - return k + x - - -def trajectory_length(t, mu): - s = rescale(mu) - return jnp.rint(0.5 + halton(t) * s) - - -def proposal(hamiltonian_step, d): - def prop(t, x, g, random_key, L, eps, sigma): - # jiter the number of steps - num_steps = jnp.rint(2 * halton(t) * L / eps).astype(int) - - # full momentum refreshment - u = full_refresh(random_key, d) - - # do num_steps of the Hamiltonian dynamics - - def body(i, state): - x, u, l, g, kinetic_energy = state - xx, uu, ll, gg, kinetic_change = hamiltonian_step( - x=x, u=u, g=g, eps=eps, sigma=sigma - ) - - return xx, uu, ll, gg, kinetic_energy + kinetic_change - - xx, uu, ll, gg, kinetic_change = jax.fori_loop( - 0, num_steps, body, (x, u, 0.0, g, 0.0) - ) - - return xx, ll, gg, kinetic_change - - return prop - - -def build_kernel(grad_nlogp, d, integrator, transform, params): - L, eps, sigma = params - - hamiltonian_step, _ = integrator( - T=update_position(grad_nlogp), V=update_momentum(d), d=d - ) - get_proposal = proposal(hamiltonian_step, d) - - def kernel(rng_key: PRNGKey, state: RMCHMCState) -> tuple[RMCHMCState, MCLMCInfo]: - key1, key2 = jax.random.split(rng_key) - - t, x, l, g = state - xx, ll, gg, kinetic_change = get_proposal(t, x, g, key1, L, eps, sigma) - de = kinetic_change + ll - l - - # accept/reject - - acc_prob = jnp.clip(jnp.exp(-de), 0, 1) - accept = jax.random.bernoulli(key2, acc_prob) - xx, ll, gg = jax.tree_util.tree_map( - lambda new, old: jax.lax.select(accept, new, old), (xx, ll, gg), (x, l, g) - ) - - return RMCHMCState(t + 1.0, xx, ll, gg), MCLMCInfo(transform(xx), ll, de) - - return kernel - - -class rmchmc: - """todo: add documentation""" - - init = staticmethod(init) - build_kernel = staticmethod(build_kernel) - - def __new__( # type: ignore[misc] - cls, - logdensity_fn: Callable, - d: int, - transform: Callable, - params: Parameters, - *, - integrator=minimal_norm, - ) -> SamplingAlgorithm: - grad_nlogp = jax.value_and_grad(lambda x: -logdensity_fn(x)) - - kernel = cls.build_kernel(grad_nlogp, d, integrator, transform, params) - - def init_fn(position: ArrayLikeTree): - return cls.init(position, logdensity_fn) - - def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - ) - - return SamplingAlgorithm(init_fn, step_fn) From 2a21c563ab0d9252bee9f49d514dd95a2e375890 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 02:43:19 +0100 Subject: [PATCH 66/78] INCREASE CODE COVERAGE --- blackjax/mcmc/mclmc.py | 2 +- tests/mcmc/test_sampling.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index cfa3bb1f5..a7df67a45 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -161,9 +161,9 @@ class mclmc: def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, - transform: Callable, L, step_size, + transform: Callable = (lambda x: x), integrator=noneuclidean_mclachlan, seed=1, ) -> SamplingAlgorithm: diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 44b83159a..4f897efe0 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -110,10 +110,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): keys = jax.random.split(key, num_steps) + sampling_alg = blackjax.mcmc.mclmc.mclmc( + logdensity_fn, + L=blackjax_mclmc_sampler_params.L, + step_size=blackjax_mclmc_sampler_params.step_size, + ) + _, blackjax_mclmc_result = jax.lax.scan( - f=lambda state, key: kernel( - 
L=blackjax_mclmc_sampler_params.L, - step_size=blackjax_mclmc_sampler_params.step_size, + f=lambda state, key: sampling_alg.step( rng_key=key, state=state, ), From 67f0de987c066f806af420944e6f2390c76a08b5 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 11:52:20 +0100 Subject: [PATCH 67/78] REMOVE REDUNDANT LINE --- blackjax/mcmc/integrators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index c84502517..840693f81 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -1,4 +1,3 @@ -# @title `integrators.py` from https://github.com/blackjax-devs/blackjax/pull/589 # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); From 3f55f5fe20d51980b328e6b908784478a8de83c1 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 12:14:07 +0100 Subject: [PATCH 68/78] ADD NAME 'mclmc' --- blackjax/__init__.py | 2 ++ tests/mcmc/test_sampling.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index d0c6b8c7a..f6068a9b6 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -4,6 +4,7 @@ from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation +from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat from .mcmc.barker import barker_proposal @@ -14,6 +15,7 @@ from .mcmc.marginal_latent_gaussian import mgrad_gaussian from .mcmc.mclmc import mclmc from .mcmc.nuts import nuts +from .mcmc.mclmc import mclmc from .mcmc.periodic_orbital import orbital_hmc from .mcmc.random_walk import additive_step_random_walk, irmh, rmh from .optimizers import dual_averaging, lbfgs diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 4f897efe0..969f58c2e 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -100,7 +100,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, - ) = blackjax.adaptation.mclmc_adaptation.mclmc_find_L_and_step_size( + ) = blackjax.mclmc_find_L_and_step_size( kernel=kernel, num_steps=num_steps, state=initial_state, @@ -110,7 +110,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): keys = jax.random.split(key, num_steps) - sampling_alg = blackjax.mcmc.mclmc.mclmc( + sampling_alg = blackjax.mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, From 666c540384c88ee313988c6cb12a8959caba326b Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 13:10:16 +0100 Subject: [PATCH 69/78] SPLIT KEYS AND FIX DOCSTRING --- blackjax/adaptation/mclmc_adaptation.py | 12 ++++++------ tests/mcmc/test_sampling.py | 7 +++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index e97992a1a..a57d1fd4d 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -38,17 +38,16 @@ class MCLMCAdaptationState(NamedTuple): def mclmc_find_L_and_step_size( - kernel, + mclmc_kernel, num_steps, state, - part1_key, - part2_key, + rng_key, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1, ): """ - Finds the optimal value of L (step 
size) for the MCLMC algorithm. + Finds the optimal value of the parameters for the MCLMC algorithm. Args: kernel: The kernel function used for the MCMC algorithm. @@ -65,9 +64,10 @@ def mclmc_find_L_and_step_size( dim = pytree_size(state.position) params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) varEwanted = 5e-4 + part1_key, part2_key = jax.random.split(rng_key, 2) state, params = make_L_step_size_adaptation( - kernel=kernel, + kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, frac_tune2=frac_tune2, @@ -77,7 +77,7 @@ def mclmc_find_L_and_step_size( )(state, params, num_steps, part1_key) if frac_tune3 != 0: - state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)( + state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)( state, params, num_steps, part2_key ) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 969f58c2e..1b6f5b982 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -85,7 +85,7 @@ def regression_logprob(self, log_scale, coefs, preds, x): return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): - init_key, part1_key, part2_key, run_key = jax.random.split(key, 4) + init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=key @@ -101,11 +101,10 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): blackjax_state_after_tuning, blackjax_mclmc_sampler_params, ) = blackjax.mclmc_find_L_and_step_size( - kernel=kernel, + mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, - part1_key=key, - part2_key=key, + rng_key=tune_key, ) keys = jax.random.split(key, num_steps) From c1615f59d130b47b91868d1037410e83d515cbbd Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 21:34:50 +0100 Subject: [PATCH 70/78] FIX MINOR ERRORS --- blackjax/__init__.py | 2 +- blackjax/adaptation/mclmc_adaptation.py | 5 ++- blackjax/mcmc/mclmc.py | 42 ++++++++++++++++++++----- blackjax/util.py | 25 --------------- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index f6068a9b6..95e6d04f5 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -13,7 +13,6 @@ from .mcmc.hmc import dynamic_hmc, hmc from .mcmc.mala import mala from .mcmc.marginal_latent_gaussian import mgrad_gaussian -from .mcmc.mclmc import mclmc from .mcmc.nuts import nuts from .mcmc.mclmc import mclmc from .mcmc.periodic_orbital import orbital_hmc @@ -54,6 +53,7 @@ "meads_adaptation", "chees_adaptation", "pathfinder_adaptation", + "mclmc_find_L_and_step_size" # mclmc adaptation "adaptive_tempered_smc", # smc "tempered_smc", "meanfield_vi", # variational inference diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index a57d1fd4d..1846eb1b3 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -107,10 +107,9 @@ def predictor(state_old, params, adaptive_state, rng_key): state_new, info = kernel( rng_key=rng_key, state=state_old, L=params.L, step_size=params.step_size ) - energy_change = info.dE # step updating success, state, step_size_max, energy_change = handle_nans( - state_old, state_new, params.step_size, step_size_max, energy_change + state_old, state_new, params.step_size, step_size_max, info.energy_change ) # Warning: var = 0 if there were nans, but we will give 
it a very small weight @@ -222,7 +221,7 @@ def adaptation_L(state, params, num_steps, key): ) # TODO: should only use a single chain here return state, params._replace( - L=Lfactor * params.step_size * jnp.average(num_steps / ESS) + L=Lfactor * params.step_size * jnp.mean(num_steps / ESS) ) return adaptation_L diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index a7df67a45..d4abae92f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,11 +15,13 @@ from typing import Callable, NamedTuple import jax - +import jax.numpy as jnp from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey -from blackjax.util import generate_unit_vector, partially_refresh_momentum +from blackjax.util import generate_unit_vector, pytree_size +from jax.flatten_util import ravel_pytree +from jax.random import normal __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -34,14 +36,14 @@ class MCLMCInfo(NamedTuple): The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace. logdensity : The log-density of the distribution at the current step of the MCLMC chain. - dE : + energy_change : The difference in energy between the current and previous step. """ transformed_position: Array logdensity: float kinetic_change: float - dE: float + energy_change: float def init(x_initial: ArrayLike, logdensity_fn, rng_key): @@ -85,10 +87,9 @@ def kernel( state, step_size ) - # dim = position.shape[0] - dim = 2 - # Langevin-like noise + dim = pytree_size(position) + # Langevin-like noise momentum, dim = partially_refresh_momentum( momentum=momentum, rng_key=rng_key, L=L, step_size=step_size ) @@ -98,7 +99,7 @@ def kernel( ), MCLMCInfo( transformed_position=transform(position), logdensity=logdensity, - dE=kinetic_change - logdensity + state.logdensity, + energy_change=kinetic_change - logdensity + state.logdensity, kinetic_change=kinetic_change * (dim - 1), ) @@ -176,3 +177,28 @@ def init_fn(position: ArrayLike): return cls.init(position, logdensity_fn, jax.random.PRNGKey(seed)) return SamplingAlgorithm(init_fn, update_fn) + + +def partially_refresh_momentum(momentum, rng_key, step_size, L): + """Adds a small noise to momentum and normalizes. + + Parameters + ---------- + rng_key: + The pseudo-random number generator key used to generate random numbers. + momentum: + PyTree that the structure the output should to match. + step_size: + Step size + L: + controls rate of momentum change + + Returns + ------- + momentum with random change in angle + """ + m, unravel_fn = ravel_pytree(momentum) + dim = m.shape[0] + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) + return unravel_fn((m + z) / jnp.sqrt(jnp.sum(jnp.square(m + z)))), dim diff --git a/blackjax/util.py b/blackjax/util.py index 1c764137d..a3a7226a6 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -104,31 +104,6 @@ def generate_unit_vector( return unravel_fn(sample / jnp.linalg.norm(sample)) -def partially_refresh_momentum(momentum, rng_key, step_size, L): - """Adds a small noise to momentum and normalizes. - - Parameters - ---------- - rng_key: - The pseudo-random number generator key used to generate random numbers. - momentum: - PyTree that the structure the output should to match. 
- step_size: - Step size - L: - controls rate of momentum change - - Returns - ------- - momentum with random change in angle - """ - m, unravel_fn = ravel_pytree(momentum) - dim = m.shape[0] - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) - z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.sqrt(jnp.sum(jnp.square(m + z)))), dim - - def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" return sum(jnp.size(value) for value in tree_leaves(pytree)) From ae1bf3054933bc4ff259096463d12e0563e67bdf Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 21:36:25 +0100 Subject: [PATCH 71/78] FIX MINOR ERRORS --- tests/mcmc/test_sampling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 1b6f5b982..db9ef9944 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -88,7 +88,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( - x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=key + x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) kernel = blackjax.mcmc.mclmc.build_kernel( @@ -107,7 +107,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): rng_key=tune_key, ) - keys = jax.random.split(key, num_steps) + keys = jax.random.split(run_key, num_steps) sampling_alg = blackjax.mclmc( logdensity_fn, @@ -116,8 +116,8 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): ) _, blackjax_mclmc_result = jax.lax.scan( - f=lambda state, key: sampling_alg.step( - rng_key=key, + f=lambda state, k: sampling_alg.step( + rng_key=k, state=state, ), xs=keys, From 0902a1c27a0311c7043619444fb5d04af91e85d8 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 23:07:18 +0100 Subject: [PATCH 72/78] RANDOMIZE KEYS (reversion) --- blackjax/adaptation/mclmc_adaptation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 1846eb1b3..b19bad9ff 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -166,8 +166,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps1, num_steps2 = int(num_steps * frac_tune1), int( num_steps * frac_tune2 ) - # L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) - L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2)) + L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) @@ -201,8 +200,7 @@ def make_adaptation_L(kernel, frac, Lfactor): def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) - # adaptation_L_keys = jax.random.split(key, num_steps) - adaptation_L_keys = jnp.array([key] * (num_steps)) + adaptation_L_keys = jax.random.split(key, num_steps) # run kernel in the normal way state, info = jax.lax.scan( From 2e3c80bfb5f1d1601c889333deac27d753357333 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 2 Dec 2023 23:13:10 +0100 Subject: [PATCH 73/78] PRECOMMIT CLEAN UP --- blackjax/__init__.py | 6 +++--- blackjax/mcmc/mclmc.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) 
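Editor's note (not part of the patches): the partial momentum refresh relocated into blackjax/mcmc/mclmc.py above adds Gaussian noise of scale nu = sqrt((exp(2 * step_size / L) - 1) / dim) and renormalizes, so the momentum stays on the unit sphere while its direction drifts at a rate controlled by step_size / L. The sketch below is a simplified, flat-array restatement of that formula for illustration only; the library function in the diff additionally handles arbitrary pytrees via ravel_pytree and returns the dimension alongside the momentum.

import jax
import jax.numpy as jnp

def partially_refresh_momentum_flat(momentum, rng_key, step_size, L):
    # noise scale taken from the diff: nu = sqrt((exp(2 * eps / L) - 1) / dim)
    dim = momentum.shape[0]
    nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
    z = nu * jax.random.normal(rng_key, shape=momentum.shape, dtype=momentum.dtype)
    # add the noise, then project back onto the unit sphere
    return (momentum + z) / jnp.linalg.norm(momentum + z)

u = jnp.array([1.0, 0.0, 0.0])  # a unit momentum vector
uu = partially_refresh_momentum_flat(u, jax.random.PRNGKey(0), step_size=0.1, L=1.0)
print(jnp.linalg.norm(uu))  # ~1.0: still a unit vector
print(jnp.dot(u, uu))       # close to 1.0: only a small change of direction
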
diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 95e6d04f5..ea49bacbd 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -1,10 +1,10 @@ from blackjax._version import __version__ from .adaptation.chees_adaptation import chees_adaptation +from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation -from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat from .mcmc.barker import barker_proposal @@ -13,8 +13,8 @@ from .mcmc.hmc import dynamic_hmc, hmc from .mcmc.mala import mala from .mcmc.marginal_latent_gaussian import mgrad_gaussian -from .mcmc.nuts import nuts from .mcmc.mclmc import mclmc +from .mcmc.nuts import nuts from .mcmc.periodic_orbital import orbital_hmc from .mcmc.random_walk import additive_step_random_walk, irmh, rmh from .optimizers import dual_averaging, lbfgs @@ -53,7 +53,7 @@ "meads_adaptation", "chees_adaptation", "pathfinder_adaptation", - "mclmc_find_L_and_step_size" # mclmc adaptation + "mclmc_find_L_and_step_size", # mclmc adaptation "adaptive_tempered_smc", # smc "tempered_smc", "meanfield_vi", # variational inference diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index d4abae92f..17289d8c7 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -16,12 +16,13 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree +from jax.random import normal + from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import Array, ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size -from jax.flatten_util import ravel_pytree -from jax.random import normal __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] From 604b5a9e5f08dc73965331a4489d3036c882010d Mon Sep 17 00:00:00 2001 From: = Date: Sun, 3 Dec 2023 22:40:10 +0100 Subject: [PATCH 74/78] ADD KWARGS FOR DEFAULT HYPERPARAMS --- blackjax/adaptation/mclmc_adaptation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index b19bad9ff..033403643 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -45,6 +45,9 @@ def mclmc_find_L_and_step_size( frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1, + varEwanted=5e-4, + sigma_xi=1.5, + num_effective_samples=150, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. 
@@ -72,8 +75,8 @@ def mclmc_find_L_and_step_size( frac_tune1=frac_tune1, frac_tune2=frac_tune2, varEwanted=varEwanted, - sigma_xi=1.5, - num_effective_samples=150, + sigma_xi=sigma_xi, + num_effective_samples=num_effective_samples, )(state, params, num_steps, part1_key) if frac_tune3 != 0: From 50a82430ed9d0acad20ae548ff09868fc51d6d97 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 5 Dec 2023 13:07:41 +0100 Subject: [PATCH 75/78] UPDATE ESS --- blackjax/adaptation/mclmc_adaptation.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 033403643..59b2f52fe 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -214,12 +214,9 @@ def adaptation_L(state, params, num_steps, key): xs=adaptation_L_keys, ) samples = info.transformed_position # tranform is the identity here - flat_samples, _ = ravel_pytree(samples) - dim = pytree_size(state.position) - flat_samples = flat_samples.reshape(-1, dim) - ESS = 0.5 * effective_sample_size( - jnp.array([flat_samples, flat_samples]) - ) # TODO: should only use a single chain here + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + flat_samples = flat_samples.reshape(2, num_steps // 2, -1) + ESS = effective_sample_size(flat_samples) return state, params._replace( L=Lfactor * params.step_size * jnp.mean(num_steps / ESS) From a20a681ae7309d90033c8a713b02cce6256e2907 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 5 Dec 2023 13:45:57 +0100 Subject: [PATCH 76/78] NAME CHANGES --- blackjax/adaptation/mclmc_adaptation.py | 50 +++++++++++++------------ 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 59b2f52fe..2ece4fa26 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -45,8 +45,8 @@ def mclmc_find_L_and_step_size( frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1, - varEwanted=5e-4, - sigma_xi=1.5, + desired_energy_var=5e-4, + trust_in_estimate=1.5, num_effective_samples=150, ): """ @@ -66,7 +66,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) - varEwanted = 5e-4 + desired_energy_var = 5e-4 part1_key, part2_key = jax.random.split(rng_key, 2) state, params = make_L_step_size_adaptation( @@ -74,8 +74,8 @@ def mclmc_find_L_and_step_size( dim=dim, frac_tune1=frac_tune1, frac_tune2=frac_tune2, - varEwanted=varEwanted, - sigma_xi=sigma_xi, + desired_energy_var=desired_energy_var, + trust_in_estimate=trust_in_estimate, num_effective_samples=num_effective_samples, )(state, params, num_steps, part1_key) @@ -92,19 +92,19 @@ def make_L_step_size_adaptation( dim, frac_tune1, frac_tune2, - varEwanted=1e-3, - sigma_xi=1.5, + desired_energy_var=1e-3, + trust_in_estimate=1.5, num_effective_samples=150, ): """Adapts the stepsize and L of the MCLMC kernel. 
Designed for the unadjusted MCLMC""" - gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) + decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) def predictor(state_old, params, adaptive_state, rng_key): """does one step with the dynamics and updates the prediction for the optimal stepsize Designed for the unadjusted MCHMC""" - W, F, step_size_max = adaptive_state + time, x_average, step_size_max = adaptive_state # dynamics state_new, info = kernel( @@ -117,37 +117,39 @@ def predictor(state_old, params, adaptive_state, rng_key): # Warning: var = 0 if there were nans, but we will give it a very small weight xi = ( - jnp.square(energy_change) / (dim * varEwanted) + jnp.square(energy_change) / (dim * desired_energy_var) ) + 1e-8 # 1e-8 is added to avoid divergences in log xi - w = jnp.exp( - -0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi)) + weight = jnp.exp( + -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate)) ) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. - F = gamma_forget * F + w * (xi / jnp.power(params.step_size, 6.0)) - W = gamma_forget * W + w + x_average = decay_rate * x_average + weight * ( + xi / jnp.power(params.step_size, 6.0) + ) + time = decay_rate * time + weight step_size = jnp.power( - F / W, -1.0 / 6.0 + x_average / time, -1.0 / 6.0 ) # We use the Var[E] = O(eps^6) relation here. step_size = (step_size < step_size_max) * step_size + ( step_size > step_size_max ) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences params_new = params._replace(step_size=step_size) - return state, params_new, params_new, (W, F, step_size_max), success + return state, params_new, params_new, (time, x_average, step_size_max), success def update_kalman(x, state, outer_weight, success, step_size): """kalman filter to estimate the size of the posterior""" - W, F1, F2 = state - w = outer_weight * step_size * success + time, x_average, x_squared_average = state + weight = outer_weight * step_size * success zero_prevention = 1 - outer_weight - F1 = (W * F1 + w * x) / ( - W + w + zero_prevention + x_average = (time * x_average + weight * x) / ( + time + weight + zero_prevention ) # Update with a Kalman filter - F2 = (W * F2 + w * jnp.square(x)) / ( - W + w + zero_prevention + x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / ( + time + weight + zero_prevention ) # Update with a Kalman filter - W += w - return (W, F1, F2) + time += weight + return (time, x_average, x_squared_average) adap0 = (0.0, 0.0, jnp.inf) From 75e71de191818dc9564b2d1049ef46c72d2dd530 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 5 Dec 2023 13:50:20 +0100 Subject: [PATCH 77/78] NAME CHANGES --- blackjax/adaptation/mclmc_adaptation.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 2ece4fa26..a165cd78d 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -100,19 +100,26 @@ def make_L_step_size_adaptation( decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) - def predictor(state_old, params, adaptive_state, rng_key): + def predictor(previous_state, params, adaptive_state, rng_key): """does one step with the dynamics and updates the prediction for the optimal stepsize Designed for the unadjusted MCHMC""" time, x_average, step_size_max = adaptive_state # 
dynamics - state_new, info = kernel( - rng_key=rng_key, state=state_old, L=params.L, step_size=params.step_size + next_state, info = kernel( + rng_key=rng_key, + state=previous_state, + L=params.L, + step_size=params.step_size, ) # step updating success, state, step_size_max, energy_change = handle_nans( - state_old, state_new, params.step_size, step_size_max, info.energy_change + previous_state, + next_state, + params.step_size, + step_size_max, + info.energy_change, ) # Warning: var = 0 if there were nans, but we will give it a very small weight @@ -227,16 +234,16 @@ def adaptation_L(state, params, num_steps, key): return adaptation_L -def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change): +def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" reduced_step_size = 0.8 - p, unravel_fn = ravel_pytree(state_new.position) + p, unravel_fn = ravel_pytree(next_state.position) nonans = jnp.all(jnp.isfinite(p)) state, step_size, kinetic_change = jax.tree_util.tree_map( lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), - (state_new, step_size_max, kinetic_change), - (state_old, step_size * reduced_step_size, 0.0), + (next_state, step_size_max, kinetic_change), + (previous_state, step_size * reduced_step_size, 0.0), ) return nonans, state, step_size, kinetic_change From 70f1dd5c1b50fc7db3e7de23d535928c04771a86 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 5 Dec 2023 14:06:15 +0100 Subject: [PATCH 78/78] MINOR FIXES --- blackjax/adaptation/mclmc_adaptation.py | 51 ++++++++++++++++++++----- blackjax/mcmc/mclmc.py | 2 +- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index a165cd78d..44a2944fc 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -53,20 +53,50 @@ def mclmc_find_L_and_step_size( Finds the optimal value of the parameters for the MCLMC algorithm. Args: - kernel: The kernel function used for the MCMC algorithm. - num_steps: The number of MCMC steps that will subsequently be run, after tuning - state: The initial state of the MCMC algorithm. - frac_tune1: The fraction of tuning for the first step of the adaptation. - frac_tune2: The fraction of tuning for the second step of the adaptation. - frac_tune3: The fraction of tuning for the third step of the adaptation. + mclmc_kernel (callable): The kernel function used for the MCMC algorithm. + num_steps (int): The number of MCMC steps that will subsequently be run, after tuning. + state (MCMCState): The initial state of the MCMC algorithm. + rng_key (jax.random.PRNGKey): The random number generator key. + frac_tune1 (float): The fraction of tuning for the first step of the adaptation. + frac_tune2 (float): The fraction of tuning for the second step of the adaptation. + frac_tune3 (float): The fraction of tuning for the third step of the adaptation. + desired_energy_var (float): The desired energy variance for the MCMC algorithm. + trust_in_estimate (float): The trust in the estimate of optimal stepsize. + num_effective_samples (int): The number of effective samples for the MCMC algorithm. Returns: - state: The final state of the MCMC algorithm. - params: The final hyperparameters of the MCMC algorithm. + tuple: A tuple containing the final state of the MCMC algorithm and the final hyperparameters. 
+ + Raises: + None + + Examples: + # Define the kernel function + def kernel(x): + return x ** 2 + + # Define the initial state + initial_state = MCMCState(position=0, momentum=1) + + # Generate a random number generator key + rng_key = jax.random.PRNGKey(0) + + # Find the optimal parameters for the MCLMC algorithm + final_state, final_params = mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=1000, + state=initial_state, + rng_key=rng_key, + frac_tune1=0.2, + frac_tune2=0.3, + frac_tune3=0.1, + desired_energy_var=1e-4, + trust_in_estimate=2.0, + num_effective_samples=200, + ) """ dim = pytree_size(state.position) params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) - desired_energy_var = 5e-4 part1_key, part2_key = jax.random.split(rng_key, 2) state, params = make_L_step_size_adaptation( @@ -161,8 +191,9 @@ def update_kalman(x, state, outer_weight, success, step_size): adap0 = (0.0, 0.0, jnp.inf) def step(iteration_state, weight_and_key): - outer_weight, rng_key = weight_and_key """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + + outer_weight, rng_key = weight_and_key state, params, adaptive_state, kalman_state = iteration_state state, params, params_final, adaptive_state, success = predictor( state, params, adaptive_state, rng_key diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 17289d8c7..a84bcaa44 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -202,4 +202,4 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.sqrt(jnp.sum(jnp.square(m + z)))), dim + return unravel_fn((m + z) / jnp.linalg.norm(m + z)), dim
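
Editor's note (not part of the patches): a minimal end-to-end sketch of the workflow this series adds, assembled from the run_mclmc test changes above: initialize the state, tune L and step_size with mclmc_find_L_and_step_size, then sample with the tuned parameters. It assumes a blackjax checkout with these patches applied; the build_kernel and blackjax.mclmc keyword names flagged in the comments are elided in the test hunks and are therefore assumptions, not confirmed API.

import jax
import jax.numpy as jnp
import blackjax
from blackjax.mcmc.integrators import noneuclidean_mclachlan

def logdensity_fn(x):
    # standard normal log-density, up to an additive constant
    return -0.5 * jnp.sum(jnp.square(x))

num_steps = 1_000
init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)

# initial state, as in the test
initial_state = blackjax.mcmc.mclmc.init(
    x_initial=jnp.ones(2), logdensity_fn=logdensity_fn, rng_key=init_key
)

# kernel used during tuning; the exact keywords are elided in the test hunk,
# so integrator= and transform= below are assumptions about the signature
kernel = blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=logdensity_fn,
    integrator=noneuclidean_mclachlan,
    transform=lambda x: x,
)

# adapt L and step_size; frac_tune*, desired_energy_var, trust_in_estimate and
# num_effective_samples keep the defaults introduced in patches 74-76
state_after_tuning, mclmc_params = blackjax.mclmc_find_L_and_step_size(
    mclmc_kernel=kernel,
    num_steps=num_steps,
    state=initial_state,
    rng_key=tune_key,
)

# run the sampler with the tuned parameters; the L= and step_size= keyword names
# are assumed from MCLMCAdaptationState, since they are elided in the test hunk
sampling_alg = blackjax.mclmc(
    logdensity_fn,
    L=mclmc_params.L,
    step_size=mclmc_params.step_size,
)
_, info = jax.lax.scan(
    f=lambda state, k: sampling_alg.step(rng_key=k, state=state),
    init=state_after_tuning,
    xs=jax.random.split(run_key, num_steps),
)
samples = info.transformed_position  # one row of transformed positions per step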