-
Notifications
You must be signed in to change notification settings - Fork 108
/
base.py
149 lines (108 loc) · 4.97 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright 2020- The Blackjax Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Optional
from typing_extensions import Protocol
from .types import ArrayLikeTree, PRNGKey
Position = ArrayLikeTree
State = NamedTuple
Info = NamedTuple
class InitFn(Protocol):
"""A `Callable` used to initialize the kernel state.
Sampling algorithms often need to carry over some informations between
steps, often to avoid computing the same quantity twice. Therefore the
kernels do not operate on the chain positions themselves, but on states that
contain this position and other information.
The `InitFn` returns the state corresponding to a chain position. This state
can then be passed to the `update` function of the `SamplingAlgorithm`.
"""
def __call__(self, position: Position, rng_key: Optional[PRNGKey]) -> State:
"""Initialize the algorithm's state.
Parameters
----------
position
A chain position.
Returns
-------
The kernel state that corresponds to the position.
"""
class UpdateFn(Protocol):
"""A transition kernel used as the `update` of a `SamplingAlgorithms`.
Kernels are pure functions and are idempotent. They necessarily take a
random state `rng_key` and the current kernel state (which contains the
current position) as parameters, return a new state and some information
about the transtion.
Update functions is a simplified yet universal interface with every sampling
algorithm. In essence, what all these algorithms do is take a rng state, a
chain state (possibly a batch of data) and return a new state and some
information about the transition.
"""
def __call__(self, rng_key: PRNGKey, state: State) -> tuple[State, Info]:
"""Update the current state using the sampling algorithm.
Parameters
----------
rng_key:
The random state used by JAX's random numbers generator.
state:
The current kernel state. The kernel state contains the current
chain position as well as other information the kernel needs to
carry over from the previous step.
Returns
-------
A new state, as well as a NamedTuple that contains extra information
about the transition that does not need to be carried over to the next
step.
"""
class SamplingAlgorithm(NamedTuple):
"""A pair of functions that represents a MCMC sampling algorithm.
Blackjax sampling algorithms are implemented as a pair of pure functions: a
kernel, that generates a new sample from the current state, and an
initialization function that creates a kernel state from a chain position.
As they represent Markov kernels, the kernel functions are pure functions
and do not have internal state. To save computation time they also operate
on states which contain the chain state and additional information that
needs to be carried over for the next step.
init:
A pure function which when called with the initial position and the
target density probability function will return the kernel's initial
state.
step:
A pure function that takes a rng key, a state and possibly some
parameters and returns a new state and some information about the
transition.
"""
init: InitFn
step: UpdateFn
class VIAlgorithm(NamedTuple):
"""A pair of functions that represents a Variational Inference algorithm.
Blackjax variational inference algorithms are implemented as a pair of pure
functions: an approximator, which takes a target probability density (and
potentially a guide), and a sampling function that uses the approximation to
draw samples.
approximate
A pure function, which when called with an initial position (and
potentially a guide function) returns a state that allows to build
an approximation to the target probability density function.
sample
A pure function which returns samples from the approximation computed
by `approximate`.
"""
init: Callable
step: Callable
sample: Callable
class RunFn(Protocol):
"""A `Callable` used to run the adaptation procedure."""
def __call__(self, rng_key: PRNGKey, position: ArrayLikeTree):
"""Run the compiled algorithm."""
class AdaptationAlgorithm(NamedTuple):
"""A function that implements an adaptation algorithm."""
run: RunFn