diff --git a/hipaasat/solvers.py b/hipaasat/solvers.py new file mode 100644 index 0000000..9944cbc --- /dev/null +++ b/hipaasat/solvers.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +import sys + +from .cnf import check_consistency, Clause, CNF, Literal, simplify + +from typing import Optional, Tuple + +class SATSolver(ABC): + """Abstract class that solves boolean satisfyibility problems.""" + + @abstractmethod + def solve(self, cnf: CNF) -> Tuple[bool, Optional[CNF]]: + ... + +class DPLL(SATSolver): + """Davis–Putnam–Logemann–Loveland (DPLL) boolean satisfyiblity solver.""" + + def solve(self, cnf: CNF) -> Tuple[bool, Optional[CNF]]: + consistent = check_consistency(cnf) + if consistent is not None: + return consistent, cnf + simplified_cnf = simplify(cnf) + if simplified_cnf is None: + return False, cnf + consistent = check_consistency(simplified_cnf) + if consistent is not None: + return consistent, cnf + + # Grab unassigned variable from shortest clause + shortest_clause = self._find_shortest_nonempty_clause(cnf) + assert shortest_clause + lit = self._get_unassigned_literal(shortest_clause) + + solved, result_cnf = self.solve(cnf.assign(lit.name, True)) + if not solved: + solved, result_cnf = self.solve(cnf.assign(lit.name, False)) + + return solved, result_cnf + + def _find_shortest_nonempty_clause(self, cnf: CNF) -> Optional[Clause]: + shortest_clause = None + shortest_clause_length = sys.maxsize # hopefully there's not a clause with 2^32 or 2^64 literals... + for c in cnf: + if c.unassigned_literal_count() < shortest_clause_length and c.unassigned_literal_count() > 0: + shortest_clause = c + shortest_clause_length = c.unassigned_literal_count() + return shortest_clause + + def _get_unassigned_literal(self, clause: Clause) -> Literal: + literals = clause.get_unassigned_literals() + return literals[0] diff --git a/tests/test_dpll.py b/tests/test_dpll.py new file mode 100644 index 0000000..49b1c11 --- /dev/null +++ b/tests/test_dpll.py @@ -0,0 +1,151 @@ +import pytest + +from hipaasat.cnf import CNF, Clause, ClauseType, Literal +from hipaasat.solvers import DPLL + +def test_single_literal(): + solver = DPLL() + cnf = CNF([ + Clause(ClauseType.OR, [ + Literal("test") + ]), + ]) + solved, result_cnf = solver.solve(cnf) + assert solved + assert result_cnf.assigned_literal_count() == result_cnf.unique_literal_count() + +def test_single_literal_multiple_clauses(): + solver = DPLL() + cnf = CNF([ + Clause(ClauseType.OR, [ + Literal("test") + ]), + Clause(ClauseType.OR, [ + Literal("test") + ]), + Clause(ClauseType.OR, [ + Literal("test") + ]), + ]) + solved, result_cnf = solver.solve(cnf) + assert solved + assert result_cnf.assigned_literal_count() == result_cnf.unique_literal_count() + +def test_single_literal_multiple_clauses_unsolvable(): + solver = DPLL() + cnf = CNF([ + Clause(ClauseType.OR, [ + Literal("test") + ]), + Clause(ClauseType.OR, [ + Literal("test", negated=True) + ]), + Clause(ClauseType.OR, [ + Literal("test") + ]), + ]) + solved, _ = solver.solve(cnf) + assert not solved + +def test_multiple_literals_single_clause(): + solver = DPLL() + cnf = CNF([ + Clause(ClauseType.OR, [ + Literal("test"), Literal("test2", negated=True), Literal("test3") + ]), + ]) + solved, result_cnf = solver.solve(cnf) + assert solved + assert result_cnf.assigned_literal_count() == 1 + +def test_multiple_literals_multiple_clauses_1(): + solver = DPLL() + cnf = CNF([ + Clause(ClauseType.OR, [ + Literal("a", negated=True), Literal("b"), Literal("c"), + ]), + Clause(ClauseType.OR, [ + Literal("a"), Literal("c"), Literal("d"), + ]), + Clause(ClauseType.OR, [ + Literal("a"), Literal("c"), Literal("d", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("a"), Literal("c", negated=True), Literal("d"), + ]), + Clause(ClauseType.OR, [ + Literal("a"), Literal("c", negated=True), Literal("d", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("b", negated=True), Literal("c", negated=True), Literal("d"), + ]), + Clause(ClauseType.OR, [ + Literal("a", negated=True), Literal("b"), Literal("c", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("a", negated=True), Literal("b", negated=True), Literal("c") + ]), + ]) + solved, result_cnf = solver.solve(cnf) + assert solved + assert result_cnf.assigned_literal_count() == result_cnf.unique_literal_count() + + a = result_cnf.get_literal("a") + b = result_cnf.get_literal("b") + c = result_cnf.get_literal("c") + d = result_cnf.get_literal("d") + + assert a and a.assignment == True + assert b and b.assignment == True + assert c and c.assignment == True + assert d and d.assignment == True + +def test_multiple_literals_multiple_clauses_2(): + solver = DPLL() + cnf = CNF([ + Clause(ClauseType.OR, [ + Literal("lib-1"), Literal("lib-2"), Literal("prog-1", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("lib-2"), Literal("prog-2", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("python-2"), Literal("lib-1", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("python-3"), Literal("lib-2", negated=True), + ]), + Clause(ClauseType.OR, [ + Literal("python-3", negated=True) + ]), + Clause(ClauseType.OR, [ + Literal("prog-1"), Literal("prog-2"), + ]), + Clause(ClauseType.AT_MOST_ONE, [ + Literal("prog-1"), Literal("prog-2"), + ]), + Clause(ClauseType.AT_MOST_ONE, [ + Literal("lib-1"), Literal("lib-2"), + ]), + Clause(ClauseType.AT_MOST_ONE, [ + Literal("python-2"), Literal("python-3"), + ]), + ]) + + solved, result_cnf = solver.solve(cnf) + assert solved + assert result_cnf.assigned_literal_count() == result_cnf.unique_literal_count() + + prog1 = result_cnf.get_literal("prog-1") + prog2 = result_cnf.get_literal("prog-2") + lib1 = result_cnf.get_literal("lib-1") + lib2 = result_cnf.get_literal("lib-2") + python2 = result_cnf.get_literal("python-2") + python3 = result_cnf.get_literal("python-3") + + assert prog1 and prog1.assignment == True + assert prog2 and prog2.assignment == False + assert lib1 and lib1.assignment == True + assert lib2 and lib2.assignment == False + assert python2 and python2.assignment == True + assert python3 and python3.assignment == False diff --git a/tests/test_hipaasat.py b/tests/test_hipaasat.py deleted file mode 100644 index 01c2d71..0000000 --- a/tests/test_hipaasat.py +++ /dev/null @@ -1,5 +0,0 @@ -from hipaasat import __version__ - - -def test_version(): - assert __version__ == '0.1.0'