Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 28, 2024
1 parent c5c5842 commit 6cafae1
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 15 deletions.
7 changes: 3 additions & 4 deletions src/emcee/backends/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from __future__ import division, print_function

import json
import os
from tempfile import NamedTemporaryFile
import json

import numpy as np

Expand Down Expand Up @@ -202,7 +202,7 @@ def accepted(self):
def random_state(self):
with self.open() as f:
try:
dct = json.loads(f[self.name].attrs['random_state'])
dct = json.loads(f[self.name].attrs["random_state"])
except KeyError:
return None
return dct
Expand Down Expand Up @@ -269,8 +269,7 @@ def save_step(self, state, accepted):
g["accepted"][:] += accepted

g.attrs["random_state"] = json.dumps(
state.random_state,
cls=NumpyEncoder
state.random_state, cls=NumpyEncoder
)

g.attrs["iteration"] = iteration + 1
Expand Down
12 changes: 8 additions & 4 deletions src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,18 @@ def random_state(self):
so silently.
"""

def rng_dict(rng):
bg_state = rng.bit_generator.state
ss = rng.bit_generator.seed_seq
ss_dict = dict(
entropy=ss.entropy,
spawn_key=ss.spawn_key,
pool_size=ss.pool_size,
n_children_spawned=ss.n_children_spawned
n_children_spawned=ss.n_children_spawned,
)
return dict(bg_state=bg_state, seed_seq=ss_dict)

return rng_dict(self._random)
# return self._random.bit_generator.state

Expand All @@ -242,13 +244,15 @@ def random_state(self, state):
if it doesn't work. Don't say I didn't warn you...
"""

def _rng_fromdict(d):
bg_state = d['bg_state']
ss = np.random.SeedSequence(**d['seed_seq'])
bg = getattr(np.random, bg_state['bit_generator'])(ss)
bg_state = d["bg_state"]
ss = np.random.SeedSequence(**d["seed_seq"])
bg = getattr(np.random, bg_state["bit_generator"])(ss)
bg.state = bg_state
rng = np.random.Generator(bg)
return rng

try:
self._random = _rng_fromdict(state)
# self._random.bit_generator = state
Expand Down
4 changes: 3 additions & 1 deletion src/emcee/moves/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def get_proposal(self, s, c, random):
diffs = np.diff(c[pairs], axis=1).squeeze(axis=1) # (ns, ndim)

# Sample a gamma value for each walker following Nelson et al. (2013)
gamma = self.g0 * (1 + self.sigma * random.standard_normal((ns, 1))) # (ns, 1)
gamma = self.g0 * (
1 + self.sigma * random.standard_normal((ns, 1))
) # (ns, 1)

# In this way, sigma is the standard deviation of the distribution of gamma,
# instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006).
Expand Down
8 changes: 6 additions & 2 deletions src/emcee/moves/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def get_factor(self, rng):
return np.exp(rng.uniform(-self._log_factor, self._log_factor))

def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal(
(x0.shape)
)

def __call__(self, x0, rng):
nw, nd = x0.shape
Expand All @@ -106,7 +108,9 @@ def __call__(self, x0, rng):

class _diagonal_proposal(_isotropic_proposal):
def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal(
(x0.shape)
)


class _proposal(_isotropic_proposal):
Expand Down
6 changes: 3 additions & 3 deletions src/emcee/tests/unit/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_backend(backend, dtype, blobs):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert last1.random_state['bg_state'] == last2.random_state['bg_state']
assert last1.random_state["bg_state"] == last2.random_state["bg_state"]
if blobs:
_custom_allclose(last1.blobs, last2.blobs)
else:
Expand All @@ -137,7 +137,7 @@ def test_backend(backend, dtype, blobs):

@pytest.mark.parametrize("backend,dtype", product(other_backends, dtypes))
def test_reload(backend, dtype):
with (backend() as backend1):
with backend() as backend1:
run_sampler(backend1, dtype=dtype)

# Test the state
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_restart(backend, dtype):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert last1.random_state['bg_state'] == last2.random_state['bg_state']
assert last1.random_state["bg_state"] == last2.random_state["bg_state"]
_custom_allclose(last1.blobs, last2.blobs)

a = sampler1.acceptance_fraction
Expand Down
4 changes: 3 additions & 1 deletion src/emcee/tests/unit/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def run_sampler(
):
rng = np.random.default_rng(seed)
coords = rng.standard_normal((nwalkers, ndim))
sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, rng=rng, backend=backend)
sampler = EnsembleSampler(
nwalkers, ndim, normal_log_prob, rng=rng, backend=backend
)
sampler.run_mcmc(
coords,
nsteps,
Expand Down

0 comments on commit 6cafae1

Please sign in to comment.