Skip to content
This repository has been archived by the owner on Apr 25, 2024. It is now read-only.

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nwatson22 committed Nov 7, 2023
1 parent ff9d0c8 commit c8072a4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 66 deletions.
31 changes: 16 additions & 15 deletions src/pyk/proof/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


P = TypeVar('P', bound='Proof')
U = TypeVar('U', bound='Any')
U = TypeVar('U')


class Prover(ABC, Generic[P, U]):
Expand Down Expand Up @@ -83,30 +83,31 @@ def exec(self) -> U:


def prove_parallel(
proofs: list[Proof],
# We need a way to map proofs to provers, but for simplicity, I'll assume it as a given
provers: dict[Proof, Prover],
proofs: dict[str, Proof],
provers: dict[str, Prover],
) -> Iterable[Proof]:
pending: dict[Future[Any], Proof] = {}
pending: dict[Future[Any], str] = {}
explored: set[ProofStep] = set()

def submit(proof: Proof, pool: Executor) -> None:
prover = provers[proof]
def submit(proof_id: str, pool: Executor) -> None:
proof = proofs[proof_id]
prover = provers[proof_id]
for step in prover.steps(proof): # <-- get next steps (represented by e.g. pending nodes, ...)
if step in explored:
continue
explored.add(step)
future = pool.submit(step.exec) # <-- schedule steps for execution
pending[future] = proof
pending[future] = proof_id

with ProcessPoolExecutor(max_workers=2) as pool:
for proof in proofs:
submit(proof, pool)
for proof_id in proofs.keys():
submit(proof_id, pool)

while pending:
future = list(wait(pending).done)[0]
proof = pending[future]
prover = provers[proof]
future = wait(pending).done.pop()
proof_id = pending[future]
proof = proofs[proof_id]
prover = provers[proof_id]
update = future.result()
prover.commit(proof, update) # <-- update the proof (can be in-memory, access disk with locking, ...)

Expand All @@ -121,6 +122,6 @@ def submit(proof: Proof, pool: Executor) -> None:
assert len(list(prover.steps(proof))) == 0
break

submit(proof, pool)
submit(proof_id, pool)
pending.pop(future)
return proofs
return proofs.values()
53 changes: 2 additions & 51 deletions src/tests/integration/proof/test_parallel_prove.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import time
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -81,7 +80,6 @@ def commit(self, proof: TreeExploreProof, update: int) -> None:


def simple_tree() -> dict[int, set[int]]:
edges: dict[int, set[int]] = {}
# 0
# / \
# 1 2
Expand All @@ -91,17 +89,7 @@ def simple_tree() -> dict[int, set[int]]:
# 5 6 7
# / \
# 8 9
edges[0] = {1, 2}
edges[1] = set()
edges[2] = {3, 4}
edges[3] = {5, 6}
edges[4] = {7}
edges[5] = set()
edges[6] = set()
edges[7] = {8, 9}
edges[8] = set()
edges[9] = set()
return edges
return {0: {1, 2}, 1: set(), 2: {3, 4}, 3: {5, 6}, 4: {7}, 5: set(), 6: set(), 7: {8, 9}, 8: set(), 9: set()}


def test_multiple_provers_fails() -> None:
Expand All @@ -115,47 +103,10 @@ def test_multiple_provers_fails() -> None:
prover2.commit(proof, step.exec())


def test_steps_read_only() -> None:
def assert_proof_equals(p1: TreeExploreProof, p2: TreeExploreProof) -> None:
assert p1.edges == p2.edges
assert p1.init == p2.init
assert p1.reached == p2.reached
assert p1.target == p2.target

prover = TreeExploreProver()
proof = TreeExploreProof(0, 9, simple_tree())
while True:
initial_proof = deepcopy(proof)
steps = prover.steps(proof)
if len(list(steps)) == 0:
break
final_proof = deepcopy(proof)
assert_proof_equals(initial_proof, final_proof)
for step in steps:
prover.commit(proof, step.exec())


def test_commit_after_finished() -> None:
prover = TreeExploreProver()
proof = TreeExploreProof(0, 9, simple_tree())
results: list[int] = []
while True:
steps = prover.steps(proof)
if len(list(steps)) == 0:
break
for step in steps:
result = step.exec()
results.append(result)
prover.commit(proof, result)
prover.commit(proof, result)
for result in results:
prover.commit(proof, result)


def test_parallel_prove() -> None:
prover = TreeExploreProver()
proof = TreeExploreProof(0, 9, simple_tree())
results = prove_parallel([proof], {proof: prover})
results = prove_parallel({'proof1': proof}, {'proof1': prover})
assert len(list(results)) == 1
assert len(list(prover.steps(proof))) == 0
assert list(results)[0].status == ProofStatus.PASSED

0 comments on commit c8072a4

Please sign in to comment.