Skip to content

Commit

Permalink
Restrict operators public API
Browse files Browse the repository at this point in the history
  • Loading branch information
alecandido committed Aug 17, 2023
1 parent 7ecb58e commit d5d952c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
7 changes: 2 additions & 5 deletions src/eko/runner/managed.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ def solve(theory: TheoryCard, operator: OperatorCard, path: Path):
del eko.parts_matching[recipe]

for ep in operator.evolgrid:
parts_ = operators.retrieve(
operators.parts(ep, eko), eko.parts, eko.parts_matching
)
components = operators.retrieve(ep, eko)
target = Target.from_ep(ep)
eko.operators[target] = operators.join(parts_)
eko.operators[target] = operators.join(components)
# flush the memory
del eko.parts
del eko.parts_matching
del eko.operators[target]
31 changes: 18 additions & 13 deletions src/eko/runner/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import commons, recipes


def retrieve(
def _retrieve(
headers: List[Recipe], parts: Inventory, parts_matching: Inventory
) -> List[Operator]:
"""Retrieve parts to be joined."""
Expand All @@ -24,7 +24,18 @@ def retrieve(
return elements


def dot4(op1: npt.NDArray, op2: npt.NDArray) -> npt.NDArray:
def _parts(ep: EvolutionPoint, eko: EKO) -> List[Recipe]:
"""Determine parts required for the given evolution point operator."""
atlas = commons.atlas(eko.theory_card, eko.operator_card)
return recipes._elements(ep, atlas)


def retrieve(ep: EvolutionPoint, eko: EKO) -> List[Operator]:
"""Retrieve parts required for the given evolution point operator."""
return _retrieve(_parts(ep, eko), eko.parts, eko.parts_matching)


def _dot4(op1: npt.NDArray, op2: npt.NDArray) -> npt.NDArray:
"""Dot product between rank 4 objects.
The product is performed considering them as matrices indexed by pairs, so
Expand All @@ -34,10 +45,10 @@ def dot4(op1: npt.NDArray, op2: npt.NDArray) -> npt.NDArray:
return np.einsum("aibj,bjck->aick", op1, op2)


def dotop(op1: Operator, op2: Operator) -> Operator:
def _dotop(op1: Operator, op2: Operator) -> Operator:
r"""Dot product between two operators.
Essentially a wrapper of :func:`dot4`, applying linear error propagation,
Essentially a wrapper of :func:`_dot4`, applying linear error propagation,
if applicable.
Note
Expand Down Expand Up @@ -67,10 +78,10 @@ def dotop(op1: Operator, op2: Operator) -> Operator:
|da_i| \cdot |b_i| + |a_i| \cdot |db_i| + \mathcal{O}(d^2)
"""
val = dot4(op1.operator, op2.operator)
val = _dot4(op1.operator, op2.operator)

if op1.error is not None and op2.error is not None:
err = dot4(np.abs(op1.operator), np.abs(op2.error)) + dot4(
err = _dot4(np.abs(op1.operator), np.abs(op2.error)) + _dot4(
np.abs(op1.error), np.abs(op2.operator)
)
else:
Expand All @@ -93,10 +104,4 @@ def join(elements: List[Operator]) -> Operator:
consider if reversing the path...
"""
return reduce(dotop, reversed(elements))


def parts(ep: EvolutionPoint, eko: EKO) -> List[Recipe]:
"""Determine parts required for the given evolution point operator."""
atlas = commons.atlas(eko.theory_card, eko.operator_card)
return recipes._elements(ep, atlas)
return reduce(_dotop, reversed(elements))
6 changes: 3 additions & 3 deletions tests/eko/runner/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

from eko.io.items import Operator
from eko.io.struct import EKO
from eko.runner.operators import join, retrieve
from eko.runner.operators import _retrieve, join


def test_retrieve(ekoparts: EKO):
evhead, evop = next(iter(ekoparts.parts.cache.items()))
matchhead, matchop = next(iter(ekoparts.parts_matching.cache.items()))

els = retrieve([evhead] * 5, ekoparts.parts, ekoparts.parts_matching)
els = _retrieve([evhead] * 5, ekoparts.parts, ekoparts.parts_matching)
assert len(els) == 5
assert all(isinstance(el, Operator) for el in els)

els = retrieve(
els = _retrieve(
[evhead, matchhead, matchhead], ekoparts.parts, ekoparts.parts_matching
)
assert len(els) == 3
Expand Down

0 comments on commit d5d952c

Please sign in to comment.