Skip to content

Commit

Permalink
refactor: add OnlineVectorPicker new loss functions support
Browse files Browse the repository at this point in the history
  • Loading branch information
vsheg committed Jul 5, 2024
1 parent de428b6 commit a117770
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 39 deletions.
50 changes: 13 additions & 37 deletions moll/pick/_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


from collections.abc import Hashable, Iterable
from functools import partial

import jax.numpy as jnp
import numpy as np
Expand All @@ -22,7 +23,7 @@
SimilarityFnCallable,
SimilarityFnLiteral,
)
from ..utils import get_function_from_literal
from ..utils import get_function_from_literal, hasarg
from ._online_add import update_vectors


Expand All @@ -38,8 +39,8 @@ def __init__(
*,
dist_fn: DistanceFnCallable | DistanceFnLiteral = "euclidean",
sim_fn: SimilarityFnCallable | SimilarityFnLiteral = "identity",
loss_fn: PotentialFnCallable | PotentialFnLiteral = "hyperbolic",
p: float | int = 1,
loss_fn: PotentialFnCallable | PotentialFnLiteral = "power",
p: float | int = -1,
k_neighbors: int | float = 5, # TODO: add heuristic for better default
threshold: float = -jnp.inf,
dtype: DTypeLike | None = None,
Expand All @@ -51,14 +52,18 @@ def __init__(
self.capacity: int = capacity

self.dist_fn: DistanceFnCallable = get_function_from_literal(
dist_fn, module="moll.measures"
dist_fn, module="moll.measures._distance"
)
self.sim_fn: SimilarityFnCallable = get_function_from_literal(
sim_fn, module="moll.measures._similarity"
)
self.sim_fn: SimilarityFnCallable = self._init_sim_fn(sim_fn)

self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity)

self.p: float | int = p
self.loss_fn: PotentialFnCallable = self._init_loss_fn(loss_fn, self.p)
loss_fn = get_function_from_literal(loss_fn, module="moll.measures._loss")
loss_fn = partial(loss_fn, p=p) if hasarg(loss_fn, "p") else loss_fn
self.loss_fn: PotentialFnCallable = loss_fn

self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity)

self.threshold: float = threshold

Expand Down Expand Up @@ -92,35 +97,6 @@ def _init_k_neighbors(self, k_neighbors: int | float, capacity: int) -> int:

return k_neighbors

def _init_sim_fn(
    self, sim_fn: SimilarityFnLiteral | SimilarityFnCallable
) -> SimilarityFnCallable:
    """
    Resolve `sim_fn` into a callable similarity function.

    Args:
        sim_fn: Either a callable, which is returned unchanged, or the
            literal `"identity"`, which maps to a function returning its
            argument as is.

    Returns:
        The callable to use as the similarity function.

    Raises:
        ValueError: If `sim_fn` is neither a callable nor a known literal.
    """
    # User-supplied callables are passed through untouched.
    if callable(sim_fn):
        return sim_fn

    if sim_fn == "identity":
        return lambda x: x

    # Fail here rather than returning a non-callable that would only blow
    # up later, at the first call site, with a far less helpful error.
    raise ValueError(
        f"Unknown similarity function {sim_fn!r}; "
        "expected a callable or 'identity'"
    )

def _init_loss_fn(
    self, loss_fn: PotentialFnLiteral | PotentialFnCallable, p: float
) -> PotentialFnCallable:
    """
    Resolve `loss_fn` into a callable potential of the distance `d`.

    Supported literals (each parameterised by `p`):

    - `"hyperbolic"`: `d ** -p` for `d > 0`, `inf` otherwise.
    - `"exp"`: `exp(-p * d)`.
    - `"lj"`: `(sigma / d) ** 12 - (sigma / d) ** 6` for `d > 0`, `inf`
      otherwise, where `sigma = p * 0.5 ** (1 / 6)`.
    - `"log"`: `-log(p * d)` for `d > 0`, `inf` otherwise.

    Args:
        loss_fn: A callable, which is returned unchanged, or one of the
            literals listed above.
        p: Parameter of the built-in potentials; ignored for callables.

    Returns:
        The callable to use as the potential function.

    Raises:
        ValueError: If `loss_fn` is neither a callable nor a known literal.
    """
    # User-supplied callables are passed through untouched.
    if callable(loss_fn):
        return loss_fn

    if loss_fn == "hyperbolic":
        return lambda d: jnp.where(d > 0, jnp.power(d, -p), jnp.inf)

    if loss_fn == "exp":
        return lambda d: jnp.exp(-p * d)

    if loss_fn == "lj":
        # With this sigma the potential reaches its minimum at d == p.
        sigma: float = p * 0.5 ** (1 / 6)
        return lambda d: (
            jnp.where(
                d > 0,
                jnp.power(sigma / d, 12.0) - jnp.power(sigma / d, 6.0),
                jnp.inf,
            )
        )

    if loss_fn == "log":
        return lambda d: jnp.where(d > 0, -jnp.log(p * d), jnp.inf)

    # Fail here rather than returning a non-callable that would only blow
    # up later, at the first call site, with a far less helpful error.
    raise ValueError(
        f"Unknown loss function {loss_fn!r}; "
        "expected a callable or one of 'hyperbolic', 'exp', 'lj', 'log'"
    )

def _init_data(self, vector: Array, label=None):
"""Initialize the picker with the first vector."""
dim = vector.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion moll/pick/tests/test_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def picker_similarity_fn(request):
return OnlineVectorPicker(
capacity=5,
dist_fn=similarity_fn,
loss_fn="exp", # exp potential is used to treat negative similarities
loss_fn="exponential", # exp potential is used to treat negative similarities
)


Expand Down
8 changes: 7 additions & 1 deletion moll/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
# Names of built-in similarity functions accepted in place of a callable.
SimilarityFnLiteral = Literal["identity"]
# Callable form of a similarity function: takes an Array, returns an ArrayLike.
SimilarityFnCallable: TypeAlias = Callable[[Array], ArrayLike]

# Earlier short names of the built-in potentials; this alias is rebound by the
# assignment directly below, so only the long names remain in effect.
PotentialFnLiteral = Literal["hyperbolic", "exp", "lj", "log"]
# TODO: maybe use abbreviations?
# Names of built-in potential (loss) functions accepted in place of a callable.
PotentialFnLiteral = Literal[
"power",
"exponential",
"lennard_jones",
"logarithmic",
]
# Callable form of a potential function: takes a float, returns an ArrayLike.
# NOTE(review): SimilarityFnCallable takes an Array while this takes a float;
# confirm whether potentials are also meant to receive arrays.
PotentialFnCallable: TypeAlias = Callable[[float], ArrayLike]


Expand Down

0 comments on commit a117770

Please sign in to comment.