Skip to content

Commit

Permalink
Restructure NumPy Random implementation
Browse files Browse the repository at this point in the history
Follow-up to 3c9b258 to restructure NumPy random implementation according to a more
rigid class structure in line with [this
advice](avrae/d20#7 (comment)).
posita committed Sep 24, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 3c9b258 commit bc74905
Showing 5 changed files with 206 additions and 115 deletions.
2 changes: 1 addition & 1 deletion docs/notes.md
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
* Introduces experimental generic [``walk``][dyce.r.walk] function and supporting visitor data structures.
* Uses ``pygraphviz`` to automate class diagram generation.
(See the note on special considerations for regenerating class diagrams in the [hacking quick start](contrib.md#hacking-quick-start).)
* Uses ``numpy`` for RNG, if present.
* Introduces experimental use of ``numpy`` for RNG, if present.
* Migrates to using ``pyproject.toml`` and ``setup.cfg``.

## [0.4.0](https://github.com/posita/dyce/releases/tag/v0.4.0)
4 changes: 2 additions & 2 deletions dyce/h.py
Original file line number Diff line number Diff line change
@@ -57,9 +57,9 @@
overload,
)

from . import rng
from .bt import beartype
from .lifecycle import experimental
from .rng import RNG
from .symmetries import comb, gcd
from .types import (
CachingProtocolMeta,
@@ -1720,7 +1720,7 @@ def roll(self) -> OutcomeT:
Returns a (weighted) random outcome, sorted.
"""
return (
RNG.choices(
rng.RNG.choices(
population=tuple(self.outcomes()),
weights=tuple(self.counts()),
k=1,
138 changes: 92 additions & 46 deletions dyce/rng.py
Original file line number Diff line number Diff line change
@@ -8,76 +8,122 @@

from __future__ import annotations

from abc import ABC
from random import Random
from typing import Any, NewType, Optional, Union
from sys import version_info
from typing import NewType, Sequence, Type, Union

from .bt import beartype

__all__ = ("RNG",)


# ---- Types ---------------------------------------------------------------------------


_RandSeed = Union[int, float, str, bytes, bytearray]
_RandState = NewType("_RandState", Any)
_RandState = NewType("_RandState", object)
_RandSeed = Union[None, int, Sequence[int]]


# ---- Data ----------------------------------------------------------------------------


RNG: Random = Random()
RNG: Random


# ---- Classes -------------------------------------------------------------------------


try:
import numpy.random
from numpy.random import BitGenerator, Generator
from numpy.random import PCG64DXSM, BitGenerator, Generator, default_rng

class NumpyRandom(Random):
_BitGeneratorT = Type[BitGenerator]

class NumPyRandomBase(Random, ABC):
r"""
Defines a [``!#python
Base class for a [``#!python
random.Random``](https://docs.python.org/3/library/random.html#random.Random)
implementation that accepts and uses a [``!#python
implementation that uses a [``#!python
numpy.random.BitGenerator``](https://numpy.org/doc/stable/reference/random/bit_generators/index.html)
under the covers. Motivated by
[avrae/d20#7](https://github.com/avrae/d20/issues/7).
The [initializer][rng.NumPyRandomBase.__init__] takes an optional *seed*, which is
passed to
[``NumPyRandomBase.bit_generator``][dyce.rng.NumPyRandomBase.bit_generator] via
[``NumPyRandomBase.seed``][dyce.rng.NumPyRandomBase.seed] during construction.
"""

def __init__(self, bit_generator: BitGenerator):
self._g = Generator(bit_generator)
bit_generator: _BitGeneratorT
_generator: Generator

if version_info < (3, 11):

@beartype
def __new__(cls, seed: _RandSeed = None):
r"""
Because ``#!python random.Random`` is broken in versions <3.11, ``#!python
random.Random``’s vanilla implementation cannot accept non-hashable
values as the first argument. For example, it will reject lists of
``#!python int``s as *seed*. This implementation of ``#!python __new__``
fixes that.
"""
return super(NumPyRandomBase, cls).__new__(cls)

@beartype
def __init__(self, seed: _RandSeed = None):
# Parent calls self.seed(seed)
super().__init__(seed)

# ---- Overrides ---------------------------------------------------------------

@beartype
def getrandbits(self, k: int) -> int:
# Adapted from the implementation for random.SystemRandom.getrandbits
if k < 0:
raise ValueError("number of bits must be non-negative")

numbytes = (k + 7) // 8 # bits / 8 and rounded up
x = int.from_bytes(self.randbytes(numbytes), "big")

return x >> (numbytes * 8 - k) # trim excess bits

@beartype
# TODO(posita): See <https://github.com/python/typeshed/issues/6063>
def getstate(self) -> _RandState: # type: ignore
return _RandState(self._generator.bit_generator.state)

@beartype
def randbytes(self, n: int) -> bytes:
return self._generator.bytes(n)

@beartype
def random(self) -> float:
return self._g.random()

def seed(self, a: Optional[_RandSeed], version: int = 2) -> None:
if a is not None and not isinstance(a, (int, float, str, bytes, bytearray)):
raise ValueError(f"unrecognized seed type ({type(a)})")

bg_type = type(self._g.bit_generator)

if a is None:
self._g = Generator(bg_type())
else:
# This is somewhat fragile and may not be the best approach. It uses
# `random.Random` to generate its own state from the seed in order to
# maintain compatibility with accepted seed types. (NumPy only accepts
# ints whereas the standard library accepts ints, floats, bytes, etc.).
# That state consists of a 3-tuple: (version: int, internal_state:
# tuple[int], gauss_next: float) at least for for versions through 3 (as
# of this writing). We feed internal_state as the seed for the NumPy
# BitGenerator.
version, internal_state, _ = Random(a).getstate()
self._g = Generator(bg_type(internal_state))

def getstate(self) -> _RandState:
return _RandState(self._g.bit_generator.state)

def setstate(self, state: _RandState) -> None:
self._g.bit_generator.state = state

if hasattr(numpy.random, "PCG64DXSM"):
RNG = NumpyRandom(numpy.random.PCG64DXSM())
elif hasattr(numpy.random, "PCG64"):
RNG = NumpyRandom(numpy.random.PCG64())
elif hasattr(numpy.random, "default_rng"):
RNG = NumpyRandom(numpy.random.default_rng().bit_generator)
return self._generator.random()

@beartype
def seed( # type: ignore
self,
a: _RandSeed,
version: int = 2,
) -> None:
self._generator = default_rng(self.bit_generator(a))

@beartype
def setstate( # type: ignore
self,
# TODO(posita): See <https://github.com/python/typeshed/issues/6063>
state: _RandState,
) -> None:
self._generator.bit_generator.state = state

class PCG64DXSMRandom(NumPyRandomBase):
r"""
A [``NumPyRandomBase``][dyce.rng.NumPyRandomBase] based on
[``numpy.random.PCG64DXSM``](https://numpy.org/doc/stable/reference/random/bit_generators/pcg64dxsm.html#numpy.random.PCG64DXSM).
"""
bit_generator = PCG64DXSM

RNG = PCG64DXSMRandom()
except ImportError:
pass
RNG = Random()
4 changes: 2 additions & 2 deletions tests/test_h.py
Original file line number Diff line number Diff line change
@@ -171,13 +171,13 @@ def test_op_sub_h(self) -> None:
assert d2 - d3 == {
o_type(-2): 1,
o_type(-1): 2,
# See <https://github.com/sympy/sympy/issues/6545>
# TODO(posita): See <https://github.com/sympy/sympy/issues/6545>
o_type(0) + o_type(0): 2,
o_type(1): 1,
}, f"o_type: {o_type}; c_type: {c_type}"
assert d3 - d2 == {
o_type(-1): 1,
# See <https://github.com/sympy/sympy/issues/6545>
# TODO(posita): See <https://github.com/sympy/sympy/issues/6545>
o_type(0) + o_type(0): 2,
o_type(1): 2,
o_type(2): 1,
173 changes: 109 additions & 64 deletions tests/test_rng.py
Original file line number Diff line number Diff line change
@@ -8,93 +8,133 @@

from __future__ import annotations

from decimal import Decimal
from random import Random
from typing import Optional

import pytest

import dyce.rng
from dyce.rng import _RandSeed
from dyce.rng import RNG, _RandSeed

__all__ = ()


# ---- Data ----------------------------------------------------------------------------


SEED_INT_128 = 0x6265656663616665
SEED_FLOAT = float(
Decimal(
"9856940084378475016744131457734599215371411366662962480265638551381775059468656635085733393811201634227995293393551923733235754825282073085472925752147516616452603904"
),
)
SEED_BYTES_128 = b"beefcafe"[::-1]
SEED_INT_192 = 0x646561646265656663616665
SEED_BYTES_192 = b"deadbeefcafe"[::-1]
SEED_INT_64: int = 0x64656164
SEED_INT_128: int = 0x6465616462656566
SEED_INT_192: int = 0x646561646265656663616665
SEED_INTS: _RandSeed = (SEED_INT_64, SEED_INT_128, SEED_INT_192)


# ---- Tests ---------------------------------------------------------------------------


def test_numpy_rng() -> None:
pytest.importorskip("numpy.random", reason="requires numpy")
assert hasattr(dyce.rng, "NumpyRandom")
assert isinstance(dyce.rng.RNG, dyce.rng.NumpyRandom)


def test_numpy_rng_pcg64dxsm() -> None:
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy")

if not hasattr(numpy_random, "PCG64DXSM"):
pytest.skip("requires numpy.random.PCG64DXSM")

rng = dyce.rng.NumpyRandom(numpy_random.PCG64DXSM())
_test_w_seed_helper(rng, SEED_INT_128, 0.7903327469601987)
_test_w_seed_helper(rng, SEED_FLOAT, 0.6018795857570297)
_test_w_seed_helper(rng, SEED_BYTES_128, 0.5339952033746491)
_test_w_seed_helper(rng, SEED_INT_192, 0.9912715409588355)
_test_w_seed_helper(rng, SEED_BYTES_192, 0.13818265573158406)

with pytest.raises(ValueError):
_test_w_seed_helper(rng, object()) # type: ignore


def test_numpy_rng_pcg64() -> None:
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy")
def test_numpy_rng_installed() -> None:
try:
from dyce.rng import PCG64DXSMRandom
except ImportError:
pytest.skip("requires numpy")

if not hasattr(numpy_random, "PCG64"):
pytest.skip("requires numpy.random.PCG64")
assert isinstance(RNG, PCG64DXSMRandom)

rng = dyce.rng.NumpyRandom(numpy_random.PCG64())
_test_w_seed_helper(rng, SEED_INT_128, 0.9794491381144006)
_test_w_seed_helper(rng, SEED_FLOAT, 0.8347478482621317)
_test_w_seed_helper(rng, SEED_BYTES_128, 0.7800090883745199)
_test_w_seed_helper(rng, SEED_INT_192, 0.28018439479392754)
_test_w_seed_helper(rng, SEED_BYTES_192, 0.4814859325412144)

with pytest.raises(ValueError):
_test_w_seed_helper(rng, object()) # type: ignore


def test_numpy_rng_default() -> None:
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy")

if not hasattr(numpy_random, "default_rng"):
pytest.skip("requires numpy.random.default_rng")
def test_numpy_rng() -> None:
try:
from dyce.rng import PCG64DXSMRandom
except ImportError:
pytest.skip("requires numpy")

rng = PCG64DXSMRandom()
seed: _RandSeed
random: float
getrandbits: int
randbytes: bytes

for seed, random, getrandbits, randbytes in (
(
SEED_INT_64,
0.5066807340643421,
0x6CCCD2511ED4B58,
bytes.fromhex("6cccd2511ed4b581"),
),
(
SEED_INT_128,
0.16159916444553268,
0x32CDBF5A16905E2,
bytes.fromhex("32cdbf5a16905e29"),
),
(
SEED_INT_192,
0.09272816060986888,
0xE0D0D43C6108BD1,
bytes.fromhex("e0d0d43c6108bd17"),
),
(
SEED_INTS,
0.32331170065667836,
0x6F230DBC3C8EC45,
bytes.fromhex("6f230dbc3c8ec452"),
),
):
_test_random_w_seed_helper(rng, seed, random)
_test_getrandbits_w_seed_helper(rng, seed, 60, getrandbits)
_test_randbytes_w_seed_helper(rng, seed, randbytes)


def test_standard_rng_installed() -> None:
try:
from dyce.rng import PCG64DXSMRandom # noqa: F401

pytest.skip("requires numpy not be installed")
except ImportError:
pass

assert isinstance(RNG, Random)


def test_standard_rng() -> None:
rng = Random()

for seed in (
SEED_INT_64,
SEED_INT_128,
SEED_INT_192,
SEED_INTS,
):
_test_random_w_seed_helper(rng, seed)


def _test_getrandbits_w_seed_helper(
rng: Random,
seed: _RandSeed,
bits: int,
expected: int,
) -> None:
rng.seed(seed)
state = rng.getstate()
val = rng.getrandbits(bits)
assert val == expected
rng.setstate(state)
assert rng.getrandbits(bits) == val

rng = dyce.rng.NumpyRandom(numpy_random.default_rng().bit_generator)
_test_w_seed_helper(rng, SEED_INT_128)
_test_w_seed_helper(rng, SEED_FLOAT)
_test_w_seed_helper(rng, SEED_BYTES_128)
_test_w_seed_helper(rng, SEED_INT_192)
_test_w_seed_helper(rng, SEED_BYTES_192)

with pytest.raises(ValueError):
_test_w_seed_helper(rng, object()) # type: ignore
def _test_randbytes_w_seed_helper(
rng: Random,
seed: _RandSeed,
expected: bytes,
) -> None:
rng.seed(seed)
state = rng.getstate()
val = rng.randbytes(len(expected))
assert val == expected
rng.setstate(state)
assert rng.randbytes(len(expected)) == val
rng.setstate(state)
assert rng.randbytes(len(expected)) == val


def _test_w_seed_helper(
def _test_random_w_seed_helper(
rng: Random,
seed: _RandSeed,
expected: Optional[float] = None,
@@ -105,7 +145,12 @@ def _test_w_seed_helper(
assert val >= 0.0 and val < 1.0

if expected is not None:
assert expected == val
assert val == expected

assert type(rng)(seed).random() == val

rng.setstate(state)
assert rng.random() == val

rng.seed(seed)
assert rng.random() == val

0 comments on commit bc74905

Please sign in to comment.