From 2acfbf6c318fdc385e238c67f49b3e79751b6b8f Mon Sep 17 00:00:00 2001 From: Alberto Cabezas Gonzalez Date: Fri, 2 Jun 2023 17:36:13 +0100 Subject: [PATCH 1/2] MCMCSamplingAlgorithm -> SamplingAlgorithm --- blackjax/base.py | 2 +- blackjax/mcmc/elliptical_slice.py | 8 ++++---- blackjax/mcmc/ghmc.py | 8 ++++---- blackjax/mcmc/hmc.py | 8 ++++---- blackjax/mcmc/mala.py | 8 ++++---- blackjax/mcmc/marginal_latent_gaussian.py | 8 ++++---- blackjax/mcmc/nuts.py | 8 ++++---- blackjax/mcmc/periodic_orbital.py | 8 ++++---- blackjax/mcmc/random_walk.py | 22 +++++++++++----------- blackjax/sgmcmc/csgld.py | 8 ++++---- blackjax/sgmcmc/sgnht.py | 8 ++++---- blackjax/smc/adaptive_tempered.py | 8 ++++---- blackjax/smc/tempered.py | 8 ++++---- blackjax/vi/svgd.py | 6 +++--- 14 files changed, 59 insertions(+), 59 deletions(-) diff --git a/blackjax/base.py b/blackjax/base.py index 922746423..0ad6a1628 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -85,7 +85,7 @@ def __call__(self, rng_key: PRNGKey, state: State) -> Tuple[State, Info]: """ -class MCMCSamplingAlgorithm(NamedTuple): +class SamplingAlgorithm(NamedTuple): """A pair of functions that represents a MCMC sampling algorithm. Blackjax sampling algorithms are implemented as a pair of pure functions: a diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py index 98ae74d4b..e8010ffb5 100644 --- a/blackjax/mcmc/elliptical_slice.py +++ b/blackjax/mcmc/elliptical_slice.py @@ -17,7 +17,7 @@ import jax import jax.numpy as jnp -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise @@ -149,7 +149,7 @@ class elliptical_slice: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -161,7 +161,7 @@ def __new__( # type: ignore[misc] *, mean: Array, cov: Array, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(cov, mean) def init_fn(position: ArrayLikeTree): @@ -174,7 +174,7 @@ def step_fn(rng_key: PRNGKey, state): loglikelihood_fn, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def elliptical_proposal( diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index 53a38ab1a..62462ae68 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -21,7 +21,7 @@ import blackjax.mcmc.integrators as integrators import blackjax.mcmc.metrics as metrics import blackjax.mcmc.proposal as proposal -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise, pytree_size @@ -253,7 +253,7 @@ class ghmc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -269,7 +269,7 @@ def __new__( # type: ignore[misc] *, divergence_threshold: int = 1000, noise_gn: Callable = lambda _: 0.0, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(noise_gn, divergence_threshold) def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): @@ -286,4 +286,4 @@ def step_fn(rng_key: PRNGKey, state): delta, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 86d90c634..ecdf2394e 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -20,7 +20,7 @@ import blackjax.mcmc.metrics as metrics import blackjax.mcmc.proposal as proposal import blackjax.mcmc.trajectory as trajectory -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.mcmc.trajectory import hmc_energy from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -205,7 +205,7 @@ class hmc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -220,7 +220,7 @@ def __new__( # type: ignore[misc] *, divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(integrator, divergence_threshold) def init_fn(position: ArrayLikeTree): @@ -236,7 +236,7 @@ def step_fn(rng_key: PRNGKey, state): num_integration_steps, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def hmc_proposal( diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 7f94c5d8b..7e76202a0 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -20,7 +20,7 @@ import blackjax.mcmc.diffusions as diffusions import blackjax.mcmc.proposal as proposal -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "mala"] @@ -165,7 +165,7 @@ class mala: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -176,7 +176,7 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, step_size: float, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel() def init_fn(position: ArrayLikeTree): @@ -185,4 +185,4 @@ def init_fn(position: ArrayLikeTree): def step_fn(rng_key: PRNGKey, state): return kernel(rng_key, state, logdensity_fn, step_size) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py index 82a90bde0..3072dfab5 100644 --- a/blackjax/mcmc/marginal_latent_gaussian.py +++ b/blackjax/mcmc/marginal_latent_gaussian.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import jax.scipy.linalg as linalg -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import Array, PRNGKey __all__ = ["MarginalState", "MarginalInfo", "init_and_kernel", "mgrad_gaussian"] @@ -173,7 +173,7 @@ class mgrad_gaussian: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -182,7 +182,7 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, covariance: Array, mean: Optional[Array] = None, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: init, kernel = init_and_kernel(logdensity_fn, covariance, mean) def init_fn(position: Array): @@ -195,4 +195,4 @@ def step_fn(rng_key: PRNGKey, state, delta: float): delta, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index d22b159d1..6f3b4fc4b 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -24,7 +24,7 @@ import blackjax.mcmc.proposal as proposal import blackjax.mcmc.termination as termination import blackjax.mcmc.trajectory as trajectory -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["NUTSInfo", "init", "build_kernel", "nuts"] @@ -207,7 +207,7 @@ class nuts: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -223,7 +223,7 @@ def __new__( # type: ignore[misc] max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(integrator, divergence_threshold, max_num_doublings) def init_fn(position: ArrayLikeTree): @@ -238,7 +238,7 @@ def step_fn(rng_key: PRNGKey, state): inverse_mass_matrix, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def iterative_nuts_proposal( diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index 6e2892e66..ae0d9fede 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -19,7 +19,7 @@ import blackjax.mcmc.integrators as integrators import blackjax.mcmc.metrics as metrics -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["PeriodicOrbitalState", "init", "build_kernel", "orbital_hmc"] @@ -259,7 +259,7 @@ class orbital_hmc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -273,7 +273,7 @@ def __new__( # type: ignore[misc] period: int, *, bijection: Callable = integrators.velocity_verlet, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(bijection) def init_fn(position: ArrayLikeTree): @@ -289,7 +289,7 @@ def step_fn(rng_key: PRNGKey, state): period, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def periodic_orbital_proposal( diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index 37e6bcc1e..cd999e037 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -59,7 +59,7 @@ import numpy as np from jax import numpy as jnp -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.mcmc import proposal from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise @@ -210,7 +210,7 @@ class additive_step_random_walk: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -227,13 +227,13 @@ def normal_random_walk(cls, logdensity_fn: Callable, sigma): The value of the covariance matrix of the gaussian proposal distribution. Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ return cls(logdensity_fn, normal(sigma)) def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, random_step: Callable - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel() def init_fn(position: ArrayLikeTree): @@ -242,7 +242,7 @@ def init_fn(position: ArrayLikeTree): def step_fn(rng_key: PRNGKey, state): return kernel(rng_key, state, logdensity_fn, random_step) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_irmh() -> Callable: @@ -313,7 +313,7 @@ class irmh: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -324,7 +324,7 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, proposal_distribution: Callable, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel() def init_fn(position: ArrayLikeTree): @@ -333,7 +333,7 @@ def init_fn(position: ArrayLikeTree): def step_fn(rng_key: PRNGKey, state): return kernel(rng_key, state, logdensity_fn, proposal_distribution) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_rmh(): @@ -428,7 +428,7 @@ class rmh: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -439,7 +439,7 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, proposal_generator: Callable[[PRNGKey, ArrayLikeTree], ArrayTree], proposal_logdensity_fn: Optional[Callable[[ArrayLikeTree], ArrayTree]] = None, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel() def init_fn(position: ArrayLikeTree): @@ -454,7 +454,7 @@ def step_fn(rng_key: PRNGKey, state): proposal_logdensity_fn, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) def build_rmh_transition_energy(proposal_logdensity_fn: Optional[Callable]) -> Callable: diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 6b7309119..93a8ea1de 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -19,7 +19,7 @@ import jax import jax.numpy as jnp -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.sgmcmc.diffusions import overdamped_langevin from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -206,7 +206,7 @@ class csgld: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -220,7 +220,7 @@ def __new__( # type: ignore[misc] num_partitions: int = 512, energy_gap: float = 100, min_energy: float = 0, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(num_partitions, energy_gap, min_energy) def init_fn(position: ArrayLikeTree): @@ -246,4 +246,4 @@ def step_fn( temperature, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index 5a403080a..57b0a4ca2 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -15,7 +15,7 @@ from typing import Callable, NamedTuple, Union import blackjax.sgmcmc.diffusions as diffusions -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise @@ -117,7 +117,7 @@ class sgnht: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -129,7 +129,7 @@ def __new__( # type: ignore[misc] grad_estimator: Callable, alpha: float = 0.01, beta: float = 0.0, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(alpha, beta) def init_fn( @@ -150,4 +150,4 @@ def step_fn( rng_key, state, grad_estimator, minibatch, step_size, temperature ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index 632dc3f38..d2b24a9f7 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -20,7 +20,7 @@ import blackjax.smc.ess as ess import blackjax.smc.solver as solver import blackjax.smc.tempered as tempered -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, PRNGKey __all__ = ["build_kernel", "adaptive_tempered_smc"] @@ -130,7 +130,7 @@ class adaptive_tempered_smc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -148,7 +148,7 @@ def __new__( # type: ignore[misc] target_ess: float, root_solver: Callable = solver.dichotomy, num_mcmc_steps: int = 10, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel( logprior_fn, loglikelihood_fn, @@ -170,4 +170,4 @@ def step_fn(rng_key: PRNGKey, state): mcmc_parameters, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 9b52a86a5..f7de5768f 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import blackjax.smc as smc -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.smc.base import SMCState from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -178,7 +178,7 @@ class tempered_smc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -194,7 +194,7 @@ def __new__( # type: ignore[misc] mcmc_parameters: Dict, resampling_fn: Callable, num_mcmc_steps: int = 10, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel( logprior_fn, loglikelihood_fn, @@ -215,4 +215,4 @@ def step_fn(rng_key: PRNGKey, state, lmbda): mcmc_parameters, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py index 838921606..9ec7f28ff 100644 --- a/blackjax/vi/svgd.py +++ b/blackjax/vi/svgd.py @@ -6,7 +6,7 @@ import optax from jax.flatten_util import ravel_pytree -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree __all__ = ["svgd", "rbf_kernel", "update_median_heuristic"] @@ -139,7 +139,7 @@ class svgd: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ init = staticmethod(init) @@ -164,4 +164,4 @@ def step_fn(state, **grad_params): state = kernel_(state, grad_logdensity_fn, kernel, **grad_params) return update_kernel_parameters(state) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] From 81c1c2c5b0b48a57b6a31777e2409306125753f3 Mon Sep 17 00:00:00 2001 From: Alberto Cabezas Gonzalez Date: Mon, 12 Jun 2023 18:56:25 +0100 Subject: [PATCH 2/2] include stochastic gradient algorithms --- blackjax/sgmcmc/sghmc.py | 8 ++++---- blackjax/sgmcmc/sgld.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 0ca430077..0b1cbfd14 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -17,7 +17,7 @@ import jax import blackjax.sgmcmc.diffusions as diffusions -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise @@ -107,7 +107,7 @@ class sghmc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -120,7 +120,7 @@ def __new__( # type: ignore[misc] num_integration_steps: int = 10, alpha: float = 0.01, beta: float = 0, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(alpha, beta) def init_fn(position: ArrayLikeTree): @@ -143,4 +143,4 @@ def step_fn( temperature, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index afd7086b9..b43f3de89 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -15,7 +15,7 @@ from typing import Callable import blackjax.sgmcmc.diffusions as diffusions -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["init", "build_kernel", "sgld"] @@ -96,7 +96,7 @@ class sgld: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -106,7 +106,7 @@ class sgld: def __new__( # type: ignore[misc] cls, grad_estimator: Callable, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel() def init_fn(position: ArrayLikeTree): @@ -123,4 +123,4 @@ def step_fn( rng_key, state, grad_estimator, minibatch, step_size, temperature ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]