diff --git a/src/pyk/proof/parallel.py b/src/pyk/proof/parallel.py index ce59eca1a..153e4fed1 100644 --- a/src/pyk/proof/parallel.py +++ b/src/pyk/proof/parallel.py @@ -13,7 +13,7 @@ P = TypeVar('P', bound='Proof') -U = TypeVar('U', bound='Any') +U = TypeVar('U') class Prover(ABC, Generic[P, U]): @@ -83,30 +83,31 @@ def exec(self) -> U: def prove_parallel( - proofs: list[Proof], - # We need a way to map proofs to provers, but for simplicity, I'll assume it as a given - provers: dict[Proof, Prover], + proofs: dict[str, Proof], + provers: dict[str, Prover], ) -> Iterable[Proof]: - pending: dict[Future[Any], Proof] = {} + pending: dict[Future[Any], str] = {} explored: set[ProofStep] = set() - def submit(proof: Proof, pool: Executor) -> None: - prover = provers[proof] + def submit(proof_id: str, pool: Executor) -> None: + proof = proofs[proof_id] + prover = provers[proof_id] for step in prover.steps(proof): # <-- get next steps (represented by e.g. pending nodes, ...) if step in explored: continue explored.add(step) future = pool.submit(step.exec) # <-- schedule steps for execution - pending[future] = proof + pending[future] = proof_id with ProcessPoolExecutor(max_workers=2) as pool: - for proof in proofs: - submit(proof, pool) + for proof_id in proofs.keys(): + submit(proof_id, pool) while pending: - future = list(wait(pending).done)[0] - proof = pending[future] - prover = provers[proof] + future = wait(pending).done.pop() + proof_id = pending[future] + proof = proofs[proof_id] + prover = provers[proof_id] update = future.result() prover.commit(proof, update) # <-- update the proof (can be in-memory, access disk with locking, ...) @@ -121,6 +122,6 @@ def submit(proof: Proof, pool: Executor) -> None: assert len(list(prover.steps(proof))) == 0 break - submit(proof, pool) + submit(proof_id, pool) pending.pop(future) - return proofs + return proofs.values() diff --git a/src/tests/integration/proof/test_parallel_prove.py b/src/tests/integration/proof/test_parallel_prove.py index 7960d2531..e92e35255 100644 --- a/src/tests/integration/proof/test_parallel_prove.py +++ b/src/tests/integration/proof/test_parallel_prove.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING @@ -81,7 +80,6 @@ def commit(self, proof: TreeExploreProof, update: int) -> None: def simple_tree() -> dict[int, set[int]]: - edges: dict[int, set[int]] = {} # 0 # / \ # 1 2 @@ -91,17 +89,7 @@ def simple_tree() -> dict[int, set[int]]: # 5 6 7 # / \ # 8 9 - edges[0] = {1, 2} - edges[1] = set() - edges[2] = {3, 4} - edges[3] = {5, 6} - edges[4] = {7} - edges[5] = set() - edges[6] = set() - edges[7] = {8, 9} - edges[8] = set() - edges[9] = set() - return edges + return {0: {1, 2}, 1: set(), 2: {3, 4}, 3: {5, 6}, 4: {7}, 5: set(), 6: set(), 7: {8, 9}, 8: set(), 9: set()} def test_multiple_provers_fails() -> None: @@ -115,47 +103,10 @@ def test_multiple_provers_fails() -> None: prover2.commit(proof, step.exec()) -def test_steps_read_only() -> None: - def assert_proof_equals(p1: TreeExploreProof, p2: TreeExploreProof) -> None: - assert p1.edges == p2.edges - assert p1.init == p2.init - assert p1.reached == p2.reached - assert p1.target == p2.target - - prover = TreeExploreProver() - proof = TreeExploreProof(0, 9, simple_tree()) - while True: - initial_proof = deepcopy(proof) - steps = prover.steps(proof) - if len(list(steps)) == 0: - break - final_proof = deepcopy(proof) - assert_proof_equals(initial_proof, final_proof) - for step in steps: - prover.commit(proof, step.exec()) - - -def test_commit_after_finished() -> None: - prover = TreeExploreProver() - proof = TreeExploreProof(0, 9, simple_tree()) - results: list[int] = [] - while True: - steps = prover.steps(proof) - if len(list(steps)) == 0: - break - for step in steps: - result = step.exec() - results.append(result) - prover.commit(proof, result) - prover.commit(proof, result) - for result in results: - prover.commit(proof, result) - - def test_parallel_prove() -> None: prover = TreeExploreProver() proof = TreeExploreProof(0, 9, simple_tree()) - results = prove_parallel([proof], {proof: prover}) + results = prove_parallel({'proof1': proof}, {'proof1': prover}) assert len(list(results)) == 1 assert len(list(prover.steps(proof))) == 0 assert list(results)[0].status == ProofStatus.PASSED