Skip to content

Commit

Permalink
Merge pull request #324 from python-adaptive/average-learner-type-hints
Browse files Browse the repository at this point in the history
AverageLearner type hints
  • Loading branch information
basnijholt authored Aug 27, 2021
2 parents 82245b2 + ce02311 commit e69ab2d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 53 deletions.
41 changes: 25 additions & 16 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from math import sqrt
from typing import Callable, Dict, List, Optional, Tuple

import cloudpickle
import numpy as np

from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Float, Real
from adaptive.utils import cache_latest


Expand Down Expand Up @@ -33,7 +35,13 @@ class AverageLearner(BaseLearner):
Number of evaluated points.
"""

def __init__(self, function, atol=None, rtol=None, min_npoints=2):
def __init__(
self,
function: Callable[[int], Real],
atol: Optional[float] = None,
rtol: Optional[float] = None,
min_npoints: int = 2,
) -> None:
if atol is None and rtol is None:
raise Exception("At least one of `atol` and `rtol` should be set.")
if atol is None:
Expand All @@ -43,24 +51,24 @@ def __init__(self, function, atol=None, rtol=None, min_npoints=2):

self.data = {}
self.pending_points = set()
self.function = function
self.function = function # type: ignore
self.atol = atol
self.rtol = rtol
self.npoints = 0
# Cannot estimate standard deviation with fewer than 2 points.
self.min_npoints = max(min_npoints, 2)
self.sum_f = 0
self.sum_f_sq = 0
self.sum_f: Real = 0.0
self.sum_f_sq: Real = 0.0

@property
def n_requested(self):
def n_requested(self) -> int:
return self.npoints + len(self.pending_points)

def to_numpy(self):
"""Data as NumPy array of size (npoints, 2) with seeds and values."""
return np.array(sorted(self.data.items()))

def ask(self, n, tell_pending=True):
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[Float]]:
points = list(range(self.n_requested, self.n_requested + n))

if any(p in self.data or p in self.pending_points for p in points):
Expand All @@ -77,7 +85,7 @@ def ask(self, n, tell_pending=True):
self.tell_pending(p)
return points, loss_improvements

def tell(self, n, value):
def tell(self, n: int, value: Real) -> None:
if n in self.data:
# The point has already been added before.
return
Expand All @@ -88,16 +96,16 @@ def tell(self, n, value):
self.sum_f_sq += value ** 2
self.npoints += 1

def tell_pending(self, n):
def tell_pending(self, n: int) -> None:
self.pending_points.add(n)

@property
def mean(self):
def mean(self) -> Float:
"""The average of all values in `data`."""
return self.sum_f / self.npoints

@property
def std(self):
def std(self) -> Float:
"""The corrected sample standard deviation of the values
in `data`."""
n = self.npoints
Expand All @@ -110,7 +118,7 @@ def std(self):
return sqrt(numerator / (n - 1))

@cache_latest
def loss(self, real=True, *, n=None):
def loss(self, real: bool = True, *, n=None) -> Float:
if n is None:
n = self.npoints if real else self.n_requested
else:
Expand All @@ -120,11 +128,12 @@ def loss(self, real=True, *, n=None):
standard_error = self.std / sqrt(n)
aloss = standard_error / self.atol
rloss = standard_error / self.rtol
if self.mean != 0:
rloss /= abs(self.mean)
mean = self.mean
if mean != 0:
rloss /= abs(mean)
return max(aloss, rloss)

def _loss_improvement(self, n):
def _loss_improvement(self, n: int) -> Float:
loss = self.loss()
if np.isfinite(loss):
return loss - self.loss(n=self.npoints + n)
Expand All @@ -150,10 +159,10 @@ def plot(self):
vals = hv.Points(vals)
return hv.operation.histogram(vals, num_bins=num_bins, dimension="y")

def _get_data(self):
def _get_data(self) -> Tuple[Dict[int, Real], int, Real, Real]:
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)

def _set_data(self, data):
def _set_data(self, data: Tuple[Dict[int, Real], int, Real, Real]) -> None:
self.data, self.npoints, self.sum_f, self.sum_f_sq = data

def __getstate__(self):
Expand Down
62 changes: 25 additions & 37 deletions adaptive/learner/average_learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,7 @@
from collections import defaultdict
from copy import deepcopy
from math import hypot
from typing import (
Callable,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from typing import Callable, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple

import numpy as np
import scipy.stats
Expand All @@ -22,9 +12,9 @@

from adaptive.learner.learner1D import Learner1D, _get_intervals
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Real

number = Union[int, float, np.int_, np.float_]
Point = Tuple[int, number]
Point = Tuple[int, Real]
Points = List[Point]

__all__: List[str] = ["AverageLearner1D"]
Expand All @@ -45,7 +35,7 @@ class AverageLearner1D(Learner1D):
If not provided, then a default is used, which uses the scaled distance
in the x-y plane as the loss. See the notes for more details
of `adaptive.Learner1D` for more details.
delta : float
delta : float, optional, default 0.2
This parameter controls the resampling condition. A point is resampled
if its uncertainty is larger than delta times the smallest neighboring
interval.
Expand Down Expand Up @@ -75,10 +65,10 @@ class AverageLearner1D(Learner1D):

def __init__(
self,
function: Callable[[Tuple[int, number]], number],
bounds: Tuple[number, number],
function: Callable[[Tuple[int, Real]], Real],
bounds: Tuple[Real, Real],
loss_per_interval: Optional[
Callable[[Sequence[number], Sequence[number]], float]
Callable[[Sequence[Real], Sequence[Real]], float]
] = None,
delta: float = 0.2,
alpha: float = 0.005,
Expand Down Expand Up @@ -115,15 +105,15 @@ def __init__(
self._number_samples = SortedDict()
# This set contains the points x that have less than min_samples
# samples or less than a (neighbor_sampling*100)% of their neighbors
self._undersampled_points: Set[number] = set()
self._undersampled_points: Set[Real] = set()
# Contains the error in the estimate of the
# mean at each point x in the form {x0: error(x0), ...}
self.error: ItemSortedDict[number, float] = decreasing_dict()
self.error: ItemSortedDict[Real, float] = decreasing_dict()
#  Distance between two neighboring points in the
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
self._distances: ItemSortedDict[number, float] = decreasing_dict()
self._distances: ItemSortedDict[Real, float] = decreasing_dict()
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
self.rescaled_error: ItemSortedDict[number, float] = decreasing_dict()
self.rescaled_error: ItemSortedDict[Real, float] = decreasing_dict()

@property
def nsamples(self) -> int:
Expand Down Expand Up @@ -165,7 +155,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:

return points, loss_improvements

def _ask_for_more_samples(self, x: number, n: int) -> Tuple[Points, List[float]]:
def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
"""When asking for n points, the learner returns n times an existing point
to be resampled, since in general n << min_samples and this point will
need to be resampled many more times"""
Expand Down Expand Up @@ -200,7 +190,7 @@ def tell_pending(self, seed_x: Point) -> None:
self._update_neighbors(x, self.neighbors_combined)
self._update_losses(x, real=False)

def tell(self, seed_x: Point, y: number) -> None:
def tell(self, seed_x: Point, y: Real) -> None:
seed, x = seed_x
if y is None:
raise TypeError(
Expand All @@ -216,7 +206,7 @@ def tell(self, seed_x: Point, y: number) -> None:
self._update_data_structures(seed_x, y, "resampled")
self.pending_points.discard(seed_x)

def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
def _update_rescaled_error_in_mean(self, x: Real, point_type: str) -> None:
"""Updates ``self.rescaled_error``.
Parameters
Expand Down Expand Up @@ -253,17 +243,15 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
norm = min(d_left, d_right)
self.rescaled_error[x] = self.error[x] / norm

def _update_data(self, x: number, y: number, point_type: str) -> None:
def _update_data(self, x: Real, y: Real, point_type: str) -> None:
if point_type == "new":
self.data[x] = y
elif point_type == "resampled":
n = len(self._data_samples[x])
new_average = self.data[x] * n / (n + 1) + y / (n + 1)
self.data[x] = new_average

def _update_data_structures(
self, seed_x: Point, y: number, point_type: str
) -> None:
def _update_data_structures(self, seed_x: Point, y: Real, point_type: str) -> None:
seed, x = seed_x
if point_type == "new":
self._data_samples[x] = {seed: y}
Expand Down Expand Up @@ -331,15 +319,15 @@ def _update_data_structures(
self._update_interpolated_loss_in_interval(*interval)
self._oldscale = deepcopy(self._scale)

def _update_distances(self, x: number) -> None:
def _update_distances(self, x: Real) -> None:
x_left, x_right = self.neighbors[x]
y = self.data[x]
if x_left is not None:
self._distances[x_left] = hypot((x - x_left), (y - self.data[x_left]))
if x_right is not None:
self._distances[x] = hypot((x_right - x), (self.data[x_right] - y))

def _update_losses_resampling(self, x: number, real=True) -> None:
def _update_losses_resampling(self, x: Real, real=True) -> None:
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
# (x_left, x_right) are the "real" neighbors of 'x'.
x_left, x_right = self._find_neighbors(x, self.neighbors)
Expand Down Expand Up @@ -368,12 +356,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
if (b is not None) and right_loss_is_unknown:
self.losses_combined[x, b] = float("inf")

def _calc_error_in_mean(self, ys: Sequence[number], y_avg: number, n: int) -> float:
def _calc_error_in_mean(self, ys: Sequence[Real], y_avg: Real, n: int) -> float:
variance_in_mean = sum((y - y_avg) ** 2 for y in ys) / (n - 1)
t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
return t_student * (variance_in_mean / n) ** 0.5

def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
# Check that all x are within the bounds
# TODO: remove this requirement, all other learners add the data
# but ignore it going forward.
Expand All @@ -384,7 +372,7 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
)

# Create a mapping of points to a list of samples
mapping: DefaultDict[number, DefaultDict[int, number]] = defaultdict(
mapping: DefaultDict[Real, DefaultDict[int, Real]] = defaultdict(
lambda: defaultdict(dict)
)
for (seed, x), y in zip(xs, ys):
Expand All @@ -400,14 +388,14 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
# simultaneously, before we move on to a new x
self.tell_many_at_point(x, seed_y_mapping)

def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> None:
def tell_many_at_point(self, x: Real, seed_y_mapping: Dict[int, Real]) -> None:
"""Tell the learner about many samples at a certain location x.
Parameters
----------
x : float
Value from the function domain.
seed_y_mapping : Dict[int, number]
seed_y_mapping : Dict[int, Real]
Dictionary of ``seed`` -> ``y`` at ``x``.
"""
# Check x is within the bounds
Expand Down Expand Up @@ -456,10 +444,10 @@ def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> No
self._update_interpolated_loss_in_interval(*interval)
self._oldscale = deepcopy(self._scale)

def _get_data(self) -> SortedDict[number, number]:
def _get_data(self) -> SortedDict[Real, Real]:
return self._data_samples

def _set_data(self, data: SortedDict[number, number]) -> None:
def _set_data(self, data: SortedDict[Real, Real]) -> None:
if data:
for x, samples in data.items():
self.tell_many_at_point(x, samples)
Expand Down
7 changes: 7 additions & 0 deletions adaptive/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Union

import numpy as np

Float = Union[float, np.float_]
Int = Union[int, np.int_]
Real = Union[Float, Int]
3 changes: 3 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ exclude = .git, .tox, __pycache__, dist

[isort]
profile=black

[mypy]
ignore_missing_imports = True

0 comments on commit e69ab2d

Please sign in to comment.