From 81c1c2c5b0b48a57b6a31777e2409306125753f3 Mon Sep 17 00:00:00 2001 From: Alberto Cabezas Gonzalez Date: Mon, 12 Jun 2023 18:56:25 +0100 Subject: [PATCH] include stochastic gradient algorithms --- blackjax/sgmcmc/sghmc.py | 8 ++++---- blackjax/sgmcmc/sgld.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 0ca430077..0b1cbfd14 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -17,7 +17,7 @@ import jax import blackjax.sgmcmc.diffusions as diffusions -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise @@ -107,7 +107,7 @@ class sghmc: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -120,7 +120,7 @@ def __new__( # type: ignore[misc] num_integration_steps: int = 10, alpha: float = 0.01, beta: float = 0, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel(alpha, beta) def init_fn(position: ArrayLikeTree): @@ -143,4 +143,4 @@ def step_fn( temperature, ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index afd7086b9..b43f3de89 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -15,7 +15,7 @@ from typing import Callable import blackjax.sgmcmc.diffusions as diffusions -from blackjax.base import MCMCSamplingAlgorithm +from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["init", "build_kernel", "sgld"] @@ -96,7 +96,7 @@ class sgld: Returns ------- - A ``MCMCSamplingAlgorithm``. + A ``SamplingAlgorithm``. """ @@ -106,7 +106,7 @@ class sgld: def __new__( # type: ignore[misc] cls, grad_estimator: Callable, - ) -> MCMCSamplingAlgorithm: + ) -> SamplingAlgorithm: kernel = cls.build_kernel() def init_fn(position: ArrayLikeTree): @@ -123,4 +123,4 @@ def step_fn( rng_key, state, grad_estimator, minibatch, step_size, temperature ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]