Skip to content

Commit

Permalink
include stochastic gradient algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab committed Jun 12, 2023
1 parent 2acfbf6 commit 81c1c2c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax

import blackjax.sgmcmc.diffusions as diffusions
from blackjax.base import MCMCSamplingAlgorithm
from blackjax.base import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise

Expand Down Expand Up @@ -107,7 +107,7 @@ class sghmc:
Returns
-------
A ``MCMCSamplingAlgorithm``.
A ``SamplingAlgorithm``.
"""

Expand All @@ -120,7 +120,7 @@ def __new__( # type: ignore[misc]
num_integration_steps: int = 10,
alpha: float = 0.01,
beta: float = 0,
) -> MCMCSamplingAlgorithm:
) -> SamplingAlgorithm:
kernel = cls.build_kernel(alpha, beta)

def init_fn(position: ArrayLikeTree):
Expand All @@ -143,4 +143,4 @@ def step_fn(
temperature,
)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
8 changes: 4 additions & 4 deletions blackjax/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Callable

import blackjax.sgmcmc.diffusions as diffusions
from blackjax.base import MCMCSamplingAlgorithm
from blackjax.base import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["init", "build_kernel", "sgld"]
Expand Down Expand Up @@ -96,7 +96,7 @@ class sgld:
Returns
-------
A ``MCMCSamplingAlgorithm``.
A ``SamplingAlgorithm``.
"""

Expand All @@ -106,7 +106,7 @@ class sgld:
def __new__( # type: ignore[misc]
cls,
grad_estimator: Callable,
) -> MCMCSamplingAlgorithm:
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
Expand All @@ -123,4 +123,4 @@ def step_fn(
rng_key, state, grad_estimator, minibatch, step_size, temperature
)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]

0 comments on commit 81c1c2c

Please sign in to comment.