From d5d952c4970e03fb8c1b194d06e3805170e0f308 Mon Sep 17 00:00:00 2001 From: Alessandro Candido Date: Thu, 17 Aug 2023 16:41:27 +0200 Subject: [PATCH] Restrict operators public API --- src/eko/runner/managed.py | 7 ++----- src/eko/runner/operators.py | 31 +++++++++++++++++------------- tests/eko/runner/test_operators.py | 6 +++--- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/src/eko/runner/managed.py b/src/eko/runner/managed.py index aacd249cf..2c3d2c313 100644 --- a/src/eko/runner/managed.py +++ b/src/eko/runner/managed.py @@ -36,12 +36,9 @@ def solve(theory: TheoryCard, operator: OperatorCard, path: Path): del eko.parts_matching[recipe] for ep in operator.evolgrid: - parts_ = operators.retrieve( - operators.parts(ep, eko), eko.parts, eko.parts_matching - ) + components = operators.retrieve(ep, eko) target = Target.from_ep(ep) - eko.operators[target] = operators.join(parts_) + eko.operators[target] = operators.join(components) # flush the memory del eko.parts - del eko.parts_matching del eko.operators[target] diff --git a/src/eko/runner/operators.py b/src/eko/runner/operators.py index 2ccbd1013..702c6479c 100644 --- a/src/eko/runner/operators.py +++ b/src/eko/runner/operators.py @@ -12,7 +12,7 @@ from . import commons, recipes -def retrieve( +def _retrieve( headers: List[Recipe], parts: Inventory, parts_matching: Inventory ) -> List[Operator]: """Retrieve parts to be joined.""" @@ -24,7 +24,18 @@ def retrieve( return elements -def dot4(op1: npt.NDArray, op2: npt.NDArray) -> npt.NDArray: +def _parts(ep: EvolutionPoint, eko: EKO) -> List[Recipe]: + """Determine parts required for the given evolution point operator.""" + atlas = commons.atlas(eko.theory_card, eko.operator_card) + return recipes._elements(ep, atlas) + + +def retrieve(ep: EvolutionPoint, eko: EKO) -> List[Operator]: + """Retrieve parts required for the given evolution point operator.""" + return _retrieve(_parts(ep, eko), eko.parts, eko.parts_matching) + + +def _dot4(op1: npt.NDArray, op2: npt.NDArray) -> npt.NDArray: """Dot product between rank 4 objects. The product is performed considering them as matrices indexed by pairs, so @@ -34,10 +45,10 @@ def dot4(op1: npt.NDArray, op2: npt.NDArray) -> npt.NDArray: return np.einsum("aibj,bjck->aick", op1, op2) -def dotop(op1: Operator, op2: Operator) -> Operator: +def _dotop(op1: Operator, op2: Operator) -> Operator: r"""Dot product between two operators. - Essentially a wrapper of :func:`dot4`, applying linear error propagation, + Essentially a wrapper of :func:`_dot4`, applying linear error propagation, if applicable. Note @@ -67,10 +78,10 @@ def dotop(op1: Operator, op2: Operator) -> Operator: |da_i| \cdot |b_i| + |a_i| \cdot |db_i| + \mathcal{O}(d^2) """ - val = dot4(op1.operator, op2.operator) + val = _dot4(op1.operator, op2.operator) if op1.error is not None and op2.error is not None: - err = dot4(np.abs(op1.operator), np.abs(op2.error)) + dot4( + err = _dot4(np.abs(op1.operator), np.abs(op2.error)) + _dot4( np.abs(op1.error), np.abs(op2.operator) ) else: @@ -93,10 +104,4 @@ def join(elements: List[Operator]) -> Operator: consider if reversing the path... """ - return reduce(dotop, reversed(elements)) - - -def parts(ep: EvolutionPoint, eko: EKO) -> List[Recipe]: - """Determine parts required for the given evolution point operator.""" - atlas = commons.atlas(eko.theory_card, eko.operator_card) - return recipes._elements(ep, atlas) + return reduce(_dotop, reversed(elements)) diff --git a/tests/eko/runner/test_operators.py b/tests/eko/runner/test_operators.py index 8721fca99..77811cf3a 100644 --- a/tests/eko/runner/test_operators.py +++ b/tests/eko/runner/test_operators.py @@ -2,18 +2,18 @@ from eko.io.items import Operator from eko.io.struct import EKO -from eko.runner.operators import join, retrieve +from eko.runner.operators import _retrieve, join def test_retrieve(ekoparts: EKO): evhead, evop = next(iter(ekoparts.parts.cache.items())) matchhead, matchop = next(iter(ekoparts.parts_matching.cache.items())) - els = retrieve([evhead] * 5, ekoparts.parts, ekoparts.parts_matching) + els = _retrieve([evhead] * 5, ekoparts.parts, ekoparts.parts_matching) assert len(els) == 5 assert all(isinstance(el, Operator) for el in els) - els = retrieve( + els = _retrieve( [evhead, matchhead, matchhead], ekoparts.parts, ekoparts.parts_matching ) assert len(els) == 3