Skip to content

Commit

Permalink
refine selector & local runner
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenyu-ms committed Dec 11, 2023
1 parent 265257b commit 8332640
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 72 deletions.
98 changes: 72 additions & 26 deletions testplan/common/utils/selector.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,92 @@
from abc import ABC, abstractmethod
# decouple represetation from evaluation
# typevar always of kind *, type hint won't work much here

from dataclasses import dataclass
from typing import Union
from typing import Any, Callable, Generic, Set, TypeVar

from typing_extensions import Protocol, Self

# TODO: a la carte
SExpr = Union["And2", "Or2", "Not", "Lit", "_SExpr"]
T = TypeVar("T")
U = TypeVar("U")


class _SExpr(ABC):
@abstractmethod
def eval(self, x) -> bool:
pass
class Functor(Protocol, Generic[T]):
def map(self, f: Callable[[T], U]) -> Self:
# map :: f t -> (t -> u) -> f u
...


@dataclass
class And2(_SExpr):
lexpr: SExpr
rexpr: SExpr
class And2(Generic[T]):
lterm: T
rterm: T

def eval(self, x) -> bool:
return self.lexpr.eval(x) and self.rexpr.eval(x)
def map(self, f):
return And2(f(self.lterm), f(self.rterm))


@dataclass
class Or2(_SExpr):
lexpr: SExpr
rexpr: SExpr
class Or2(Generic[T]):
lterm: T
rterm: T

def eval(self, x) -> bool:
return self.lexpr.eval(x) or self.rexpr.eval(x)
def map(self, f):
return Or2(f(self.lterm), f(self.rterm))


@dataclass
class Not(_SExpr):
expr: SExpr
class Not(Generic[T]):
term: T

def eval(self, x) -> bool:
return not self.expr.eval(x)
def map(self, f):
return Not(f(self.term))


@dataclass
class Lit(_SExpr):
val: str
class Eq(Generic[U]):
val: U

def map(self, f):
return self


Expr = TypeVar("Expr", bound=Functor)


def cata(f: Callable, rep: Expr):
# cata :: (f t -> t) -> f (f (f ...)) -> t
return f(rep.map(lambda x: cata(f, x)))


def eval_on_set(s: Set) -> Callable:
def _(x):
if isinstance(x, Eq):
return {i for i in s if i == x.val}
if isinstance(x, And2):
return x.lterm & x.rterm
if isinstance(x, Or2):
return x.lterm | x.rterm
if isinstance(x, Not):
return s - x.term
raise TypeError(f"unexpected {x}")

return _


def apply_on_set(rep: Expr, s: Set) -> Set:
return cata(eval_on_set(s), rep)


def apply_single(rep: Expr, i: Any) -> bool:
def _(x):
if isinstance(x, Eq):
return x.val == i
if isinstance(x, And2):
return x.lterm and x.rterm
if isinstance(x, Or2):
return x.lterm or x.rterm
if isinstance(x, Not):
return not x.term
raise TypeError(f"unexpected {x}")

def eval(self, x) -> bool:
return self.val == x
return cata(_(i), rep)
20 changes: 3 additions & 17 deletions testplan/report/testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import traceback
from collections import Counter
from enum import Enum
from functools import reduce
from functools import reduce, total_ordering
from typing import Callable, Dict, List, Optional

from typing_extensions import Self
Expand All @@ -59,6 +59,7 @@
from testplan.testing import tagging


@total_ordering
class RuntimeStatus(Enum):
"""
Constants for test runtime status - for interactive mode
Expand Down Expand Up @@ -88,22 +89,7 @@ def precedent(cls, stats):
return min(stats, key=lambda stat: RUNTIMESTATUS_PRECEDENCE[stat])

def __lt__(self, other: Self) -> bool:
lhs, rhs = (
RUNTIMESTATUS_PRECEDENCE[self],
RUNTIMESTATUS_PRECEDENCE[other],
)
if lhs == rhs and self != other:
return NotImplemented
return lhs < rhs

def __le__(self, other: Self) -> bool:
lhs, rhs = (
RUNTIMESTATUS_PRECEDENCE[self],
RUNTIMESTATUS_PRECEDENCE[other],
)
if lhs == rhs and self != other:
return NotImplemented
return lhs <= rhs
return RUNTIMESTATUS_PRECEDENCE[self] < RUNTIMESTATUS_PRECEDENCE[other]

precede = __lt__

Expand Down
13 changes: 4 additions & 9 deletions testplan/runnable/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from queue import Queue
from typing import Any, Dict, Tuple

from testplan.common.utils.selector import SExpr
from testplan.common.utils.selector import Expr, apply_on_set


class InterExecutorMessageT(IntEnum):
Expand All @@ -21,10 +21,6 @@ def make_expected_abort(cls):
return cls(InterExecutorMessageT.EXPECTED_ABORT, None)


# shortcuts
InterExecutorMessage.STOP = InterExecutorMessageT.EXPECTED_ABORT


class QueueChannels:
def __init__(self):
self._qes: Dict[str, Queue] = {}
Expand All @@ -35,7 +31,6 @@ def new_channel(self, name: str) -> Tuple[str, Queue]:
# XXX: in the future
return name, self._qes[name]

def cast(self, selector: SExpr, msg: InterExecutorMessageT):
for k, q in self._qes.items():
if selector.eval(k):
q.put(msg)
def cast(self, selector: Expr, msg: InterExecutorMessageT):
for k in apply_on_set(selector, set(self._qes.keys())):
self._qes[k].put(msg)
42 changes: 23 additions & 19 deletions testplan/runners/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ class LocalRunner(Executor):
def __init__(self, **options) -> None:
super(LocalRunner, self).__init__(**options)
self._uid = "local_runner"
self._to_abort = False
self._curr_runnable_cv = threading.Condition()
self._to_skip_remaining = False
self._curr_runnable_lock = threading.Lock()
self._curr_runnable = None

def execute(self, uid: str) -> TestResult:
"""Execute item implementation."""
# First retrieve the input from its UID.
if self._to_skip_remaining:
# should be disposed immediately
return TestResult()

target = self._input[uid]

# Inspect the input type. Tasks must be materialized before
Expand All @@ -52,16 +56,21 @@ def execute(self, uid: str) -> TestResult:
if not runnable.cfg.parent:
runnable.cfg.parent = self.cfg

with self._curr_runnable_cv:
with self._curr_runnable_lock:
self._curr_runnable = runnable
self._curr_runnable_cv.notify()
result = self._curr_runnable.run()
self._curr_runnable = None
with self._curr_runnable_lock:
self._curr_runnable = None

return result

def _loop(self) -> None:
"""Execution loop implementation for local runner."""
try:
thres = self.cfg.test_breaker_thres.plan_level
except AttributeError:
thres = None

while self.active:
if self.status == self.status.STARTING:
self.status.change(self.status.STARTED)
Expand All @@ -86,19 +95,16 @@ def _loop(self) -> None:
exc,
)
finally:
with self._curr_runnable_cv:
if not self._to_abort:
with self._curr_runnable_lock:
if not self._to_skip_remaining:
# otherwise result from aborted test included
self._results[next_uid] = result
self.ongoing.pop(0)

if (
self.cfg.test_breaker_thres.plan_level
and result.report.status
<= self.cfg.test_breaker_thres.plan_level
):
if thres and result.report.status <= thres:
if self._msg_self_id is not None:
self._msg_out_channels.cast(
S.Not(S.Lit(self._msg_self_id)),
S.Not(S.Eq(self._msg_self_id)),
InterExecutorMessage.make_expected_abort(),
)
self._silently_skip_remaining()
Expand Down Expand Up @@ -143,12 +149,10 @@ def _silently_skip_remaining(self):
self.ongoing = []

def _handle_expected_abort(self, _):
with self._curr_runnable_cv:
if self._curr_runnable is None:
if len(self.ongoing):
self._curr_runnable_cv.wait()
self._curr_runnable.abort()
self._to_abort = True
with self._curr_runnable_lock:
if self._curr_runnable:
self._curr_runnable.abort()
self._to_skip_remaining = True
self._silently_skip_remaining()

def get_current_status_for_debug(self) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion testplan/runners/pools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def task_should_rerun():
):
if self._msg_self_id is not None:
self._msg_out_channels.cast(
S.Not(S.Lit(self._msg_self_id)),
S.Not(S.Eq(self._msg_self_id)),
InterExecutorMessage.make_expected_abort(),
)
self._silently_skip_remaining()
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/testplan/common/utils/test_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from dataclasses import dataclass
from functools import reduce
from typing import Generic, List, TypeVar

import testplan.common.utils.selector as S


def test_basic_op():
assert S.apply_on_set(S.Not(S.Eq("a")), {"a", "b"}) == {"b"}
assert S.apply_on_set(S.Not(S.Eq("a")), {"c", "b"}) == {"b", "c"}
assert S.apply_on_set(
S.Or2(S.Eq("a"), S.Eq("b")), {"a", "b", "c", "d"}
) == {"a", "b"}


def test_ext():

X = TypeVar("X")

@dataclass
class AndN(Generic[X]):
terms: List[X]

def map(self, f):
return AndN(list(map(f, self.terms)))

to_reuse = S.eval_on_set({"a", "b", "c"})

def _(x):
if isinstance(x, AndN):
return reduce(lambda x, y: x.intersection(y), x.terms)
return to_reuse(x)

assert S.cata(
_,
AndN(
[
S.Or2(S.Eq("a"), S.Eq("b")),
S.Or2(S.Eq("b"), S.Eq("c")),
S.Not(S.Eq("a")),
]
),
) == {"b"}
66 changes: 66 additions & 0 deletions tests/unit/testplan/runners/test_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from time import sleep

import pytest

import testplan.testing.multitest as mt
from testplan.common.utils.selector import Eq
from testplan.runnable import TestRunner as MyTestRunner
from testplan.runnable.messaging import InterExecutorMessage
from testplan.runners.local import LocalRunner


@mt.testsuite
class Suite:
def __init__(self, pre_sleep, post_sleep):
self.pre = pre_sleep
self.post = post_sleep

@mt.testcase
def case_a(self, env, result):
sleep(self.pre)
result.true(False)
sleep(self.post)

@mt.testcase
def case_b(self, env, result):
result.true(False)


MT_NAME = "dummy_mt"


def gen_mt(*suites):
return mt.MultiTest(MT_NAME, suites=suites)


@pytest.mark.parametrize(
"pre_sleep,post_sleep,out_sleep,has_result",
(
(1, 0, 0.5, False),
(0, 1, 0.5, False),
(0, 0, 0.5, True),
),
)
def test_local_simple_abort(pre_sleep, post_sleep, out_sleep, has_result):
par = MyTestRunner(name="in-the-middle-of-unit-tests")
par.add_resource(LocalRunner(), "non-express")
mt = gen_mt(Suite(pre_sleep, post_sleep))
par.add(mt, "non-express")
r: LocalRunner = par.resources["non-express"]
chs = par._exec_channels
r.start()
sleep(out_sleep)
chs.cast(Eq("non-express"), InterExecutorMessage.make_expected_abort())
while r.pending_work():
sleep(0.1)
r.stop()
if has_result:
# we don't have other runners here, casted messages might not get
# processed in time before runner dies
assert MT_NAME in r.results
repo = r.results[MT_NAME].report
assert len(repo) == 1
assert len(repo.entries[0]) == 2
else:
assert r._to_skip_remaining is True
assert len(r.results) == 0
Loading

0 comments on commit 8332640

Please sign in to comment.