Skip to content

Commit

Permalink
improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Aug 28, 2021
1 parent 9bfbeb8 commit 3d359c3
Showing 1 changed file with 87 additions and 65 deletions.
152 changes: 87 additions & 65 deletions adaptive/learner/learner1D.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
import collections.abc
import itertools
import math
import numbers
from copy import deepcopy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import cloudpickle
import numpy as np
Expand All @@ -25,14 +13,42 @@
from adaptive.learner.learnerND import volume
from adaptive.learner.triangulation import simplex_volume_in_embedding
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Float
from adaptive.types import Float, Int, Real
from adaptive.utils import cache_latest

Point = Tuple[Float, Float]
# -- types --

# Commonly used types
Interval = Union[Tuple[float, float], Tuple[float, float, int]]
NeighborsType = Dict[float, List[Optional[float]]]

# Types for loss_per_interval functions
NoneFloat = Union[Float, None]
NoneArray = Union[np.ndarray, None]
XsType0 = Tuple[Float, Float]
YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
YsType1 = Union[
Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
]
XsTypeN = Tuple[NoneFloat, ...]
YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]


__all__ = [
"uniform_loss",
"default_loss",
"abs_min_log_loss",
"triangle_loss",
"resolution_loss_function",
"curvature_loss_function",
"Learner1D",
]


@uses_nth_neighbors(0)
def uniform_loss(xs: Point, ys: Any) -> Float:
def uniform_loss(xs: XsType0, ys: YsType0) -> Float:
"""Loss function that samples the domain uniformly.
Works with `~adaptive.Learner1D` only.
Expand All @@ -52,10 +68,7 @@ def uniform_loss(xs: Point, ys: Any) -> Float:


@uses_nth_neighbors(0)
def default_loss(
xs: Point,
ys: Union[Tuple[Iterable[Float], Iterable[Float]], Point],
) -> float:
def default_loss(xs: XsType0, ys: YsType0) -> Float:
"""Calculate loss on a single interval.
Currently returns the rescaled length of the interval. If one of the
Expand All @@ -64,28 +77,23 @@ def default_loss(
"""
dx = xs[1] - xs[0]
if isinstance(ys[0], collections.abc.Iterable):
dy_vec = [abs(a - b) for a, b in zip(*ys)]
dy_vec = np.array([abs(a - b) for a, b in zip(*ys)])
return np.hypot(dx, dy_vec).max()
else:
dy = ys[1] - ys[0]
return np.hypot(dx, dy)


@uses_nth_neighbors(0)
def abs_min_log_loss(xs, ys):
def abs_min_log_loss(xs: XsType0, ys: YsType0) -> Float:
"""Calculate loss of a single interval that prioritizes the absolute minimum."""
ys = [np.log(np.abs(y).min()) for y in ys]
ys = tuple(np.log(np.abs(y).min()) for y in ys)
return default_loss(xs, ys)


@uses_nth_neighbors(1)
def triangle_loss(
xs: Sequence[Optional[Float]],
ys: Union[
Iterable[Optional[Float]],
Iterable[Union[Iterable[Float], None]],
],
) -> float:
def triangle_loss(xs: XsType1, ys: YsType1) -> Float:
assert len(xs) == 4
xs = [x for x in xs if x is not None]
ys = [y for y in ys if y is not None]

Expand All @@ -102,7 +110,9 @@ def triangle_loss(
return sum(vol(pts[i : i + 3]) for i in range(N)) / N


def resolution_loss_function(min_length=0, max_length=1):
def resolution_loss_function(
min_length: Real = 0, max_length: Real = 1
) -> Callable[[XsType0, YsType0], Float]:
"""Loss function that is similar to the `default_loss` function, but you
can set the maximum and minimum size of an interval.
Expand All @@ -125,7 +135,7 @@ def resolution_loss_function(min_length=0, max_length=1):
"""

@uses_nth_neighbors(0)
def resolution_loss(xs, ys):
def resolution_loss(xs: XsType0, ys: YsType0) -> Float:
loss = uniform_loss(xs, ys)
if loss < min_length:
# Return zero such that this interval won't be chosen again
Expand All @@ -140,11 +150,11 @@ def resolution_loss(xs, ys):


def curvature_loss_function(
area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
) -> Callable:
area_factor: Real = 1, euclid_factor: Real = 0.02, horizontal_factor: Real = 0.02
) -> Callable[[XsType1, YsType1], Float]:
# XXX: add a doc-string
@uses_nth_neighbors(1)
def curvature_loss(xs, ys):
def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
xs_middle = xs[1:3]
ys_middle = ys[1:3]

Expand All @@ -160,7 +170,7 @@ def curvature_loss(xs, ys):
return curvature_loss


def linspace(x_left: float, x_right: float, n: int) -> List[float]:
def linspace(x_left: Real, x_right: Real, n: Int) -> List[Float]:
"""This is equivalent to
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
but it is 15-30 times faster for small 'n'."""
Expand All @@ -172,7 +182,7 @@ def linspace(x_left: float, x_right: float, n: int) -> List[float]:
return [x_left + step * i for i in range(1, n)]


def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
xs = np.sort(xs)
xs_left = np.roll(xs, 1).tolist()
xs_right = np.roll(xs, -1).tolist()
Expand All @@ -182,7 +192,9 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
return SortedDict(neighbors)


def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
def _get_intervals(
x: float, neighbors: NeighborsType, nth_neighbors: int
) -> List[Tuple[float, float]]:
nn = nth_neighbors
i = neighbors.index(x)
start = max(0, i - nn - 1)
Expand Down Expand Up @@ -237,10 +249,10 @@ class Learner1D(BaseLearner):

def __init__(
self,
function: Callable,
bounds: Tuple[float, float],
loss_per_interval: Optional[Callable] = None,
) -> None:
function: Callable[[Real], Union[Float, np.ndarray]],
bounds: Tuple[Real, Real],
loss_per_interval: Optional[Callable[[XsTypeN, YsTypeN], Float]] = None,
):
self.function = function # type: ignore

if hasattr(loss_per_interval, "nth_neighbors"):
Expand All @@ -255,13 +267,13 @@ def __init__(
# the learners behavior in the tests.
self._recompute_losses_factor = 2

self.data = {}
self.pending_points = set()
self.data: Dict[Real, Real] = {}
self.pending_points: Set[Real] = set()

# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
# properties.
self.neighbors = SortedDict()
self.neighbors_combined = SortedDict()
self.neighbors: NeighborsType = SortedDict()
self.neighbors_combined: NeighborsType = SortedDict()

# Bounding box [[minx, maxx], [miny, maxy]].
self._bbox = [list(bounds), [np.inf, -np.inf]]
Expand Down Expand Up @@ -319,14 +331,14 @@ def loss(self, real: bool = True) -> float:
max_interval, max_loss = losses.peekitem(0)
return max_loss

def _scale_x(self, x: Optional[float]) -> Optional[float]:
def _scale_x(self, x: Optional[Float]) -> Optional[Float]:
if x is None:
return None
return x / self._scale[0]

def _scale_y(
self, y: Optional[Union[Float, np.ndarray]]
) -> Optional[Union[Float, np.ndarray]]:
self, y: Union[Float, np.ndarray, None]
) -> Union[Float, np.ndarray, None]:
if y is None:
return None
y_scale = self._scale[1] or 1
Expand Down Expand Up @@ -418,7 +430,7 @@ def _update_losses(self, x: float, real: bool = True) -> None:
self.losses_combined[x, b] = float("inf")

@staticmethod
def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
def _find_neighbors(x: float, neighbors: NeighborsType) -> Any:
if x in neighbors:
return neighbors[x]
pos = neighbors.bisect_left(x)
Expand All @@ -427,7 +439,7 @@ def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
x_right = keys[pos] if pos != len(neighbors) else None
return x_left, x_right

def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
if x not in neighbors: # The point is new
x_left, x_right = self._find_neighbors(x, neighbors)
neighbors[x] = [x_left, x_right]
Expand Down Expand Up @@ -461,9 +473,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
self._bbox[1][1] = max(self._bbox[1][1], y)
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]

def tell(
self, x: float, y: Union[Float, Sequence[numbers.Number], np.ndarray]
) -> None:
def tell(self, x: float, y: Union[Float, Sequence[Float], np.ndarray]) -> None:
if x in self.data:
# The point is already evaluated before
return
Expand Down Expand Up @@ -506,7 +516,17 @@ def tell_pending(self, x: float) -> None:
self._update_neighbors(x, self.neighbors_combined)
self._update_losses(x, real=False)

def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
def tell_many(
self,
xs: Sequence[Float],
ys: Union[
Sequence[Float],
Sequence[Sequence[Float]],
Sequence[np.ndarray],
],
*,
force: bool = False
) -> None:
if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
# Only run this more efficient method if there are
# at least 2 points and the amount of points added are
Expand All @@ -526,8 +546,8 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
points_combined = np.hstack([points_pending, points])

# Generate neighbors
self.neighbors = _get_neighbors_from_list(points)
self.neighbors_combined = _get_neighbors_from_list(points_combined)
self.neighbors = _get_neighbors_from_array(points)
self.neighbors_combined = _get_neighbors_from_array(points_combined)

# Update scale
self._bbox[0] = [points_combined.min(), points_combined.max()]
Expand Down Expand Up @@ -574,7 +594,7 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
# have an inf loss.
self._update_interpolated_loss_in_interval(*ival)

def ask(self, n: int, tell_pending: bool = True) -> Any:
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
"""Return 'n' points that are expected to maximally reduce the loss."""
points, loss_improvements = self._ask_points_without_adding(n)

Expand All @@ -584,7 +604,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:

return points, loss_improvements

def _ask_points_without_adding(self, n: int) -> Any:
def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
"""Return 'n' points that are expected to maximally reduce the loss.
Without altering the state of the learner"""
# Find out how to divide the n points over the intervals
Expand Down Expand Up @@ -648,7 +668,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
quals[(*xs, n + 1)] = loss_qual * n / (n + 1)

points = list(
itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
)

loss_improvements = list(
Expand All @@ -663,11 +683,13 @@ def _ask_points_without_adding(self, n: int) -> Any:

return points, loss_improvements

def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
def _loss(
self, mapping: Dict[Interval, float], ival: Interval
) -> Tuple[float, Interval]:
loss = mapping[ival]
return finite_loss(ival, loss, self._scale[0])

def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"):
def plot(self, *, scatter_or_line: str = "scatter"):
"""Returns a plot of the evaluated data.
Parameters
Expand Down Expand Up @@ -734,7 +756,7 @@ def __setstate__(self, state):
self.losses_combined.update(losses_combined)


def loss_manager(x_scale: float) -> ItemSortedDict:
def loss_manager(x_scale: float) -> Dict[Interval, float]:
def sort_key(ival, loss):
loss, ival = finite_loss(ival, loss, x_scale)
return -loss, ival
Expand All @@ -743,8 +765,8 @@ def sort_key(ival, loss):
return sorted_dict


def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
"""Get the socalled finite_loss of an interval in order to be able to
def finite_loss(ival: Interval, loss: float, x_scale: float) -> Tuple[float, Interval]:
"""Get the so-called finite_loss of an interval in order to be able to
sort intervals that have infinite loss."""
# If the loss is infinite we return the
# distance between the two points.
Expand Down

0 comments on commit 3d359c3

Please sign in to comment.