Skip to content

Commit

Permalink
fix: formulae have repr method correctly implemented, minor bugs dete…
Browse files Browse the repository at this point in the history
…cted and solved. This closes #25
  • Loading branch information
MatteoMagnini committed Nov 18, 2022
1 parent 2bea82e commit 003e614
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 38 deletions.
16 changes: 13 additions & 3 deletions psyki/logic/datalog/fuzzifiers/lukasciewicz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations
from typing import Callable
from tensorflow.keras import Model
from tensorflow.python.framework.ops import convert_to_tensor

from psyki.logic.datalog.grammar import *
from tensorflow import cast, SparseTensor, maximum, minimum, constant, reshape, reduce_max, tile
from tensorflow import cast, SparseTensor, maximum, minimum, constant, reshape, reduce_max, tile, reduce_min
from tensorflow.python.keras.backend import to_dense
from tensorflow.python.ops.array_ops import shape
from psyki.logic import Formula, get_logic_symbols_with_short_name
Expand Down Expand Up @@ -54,6 +56,10 @@ def __init__(self, class_mapping: dict[str, int], feature_mapping: dict[str, int
'+': lambda l, r: lambda x: l(x) + r(x),
'*': lambda l, r: lambda x: l(x) * r(x)
}
self._aggregate_operation = {
_logic_symbols('cj'): lambda args: lambda x: eta(reduce_max(convert_to_tensor([arg(x) for arg in args]), axis=0)),
_logic_symbols('dj'): lambda args: lambda x: eta(reduce_min(convert_to_tensor([arg(x) for arg in args]), axis=0)),
}
self._implication = ''

@staticmethod
Expand Down Expand Up @@ -104,8 +110,12 @@ def _visit_definition_clause(self, node: DefinitionClause, rhs: Clause,
self._visit(rhs, m)(x))))

def _visit_expression(self, node: Expression, local_mapping: dict[str, int] = None) -> Callable:
l, r = self._visit(node.lhs, local_mapping), self._visit(node.rhs, local_mapping)
return self._operation.get(node.op)(l, r)
if len(node.nary) <= 2:
l, r = self._visit(node.lhs, local_mapping), self._visit(node.rhs, local_mapping)
return self._operation.get(node.op)(l, r)
else:
previous_layer = [self._visit(clause, local_mapping) for clause in node.nary]
return self._aggregate_operation.get(node.op)(previous_layer)

def _visit_variable(self, node: Variable, local_mapping: dict[str, int] = None):
if node.name in self.feature_mapping.keys():
Expand Down
105 changes: 77 additions & 28 deletions psyki/logic/datalog/grammar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# TODO: refactoring
def optimize_datalog_formula(formula: Formula):
def optimize_datalog_formula(formula: Formula) -> None:
if isinstance(formula, datalog.grammar.Expression):
lhs = formula.lhs
rhs = formula.rhs
Expand Down Expand Up @@ -52,7 +52,10 @@ def __init__(self, lhs: DefinitionClause, rhs: Clause, op: str = '<-'):
self.op: str = op

def __str__(self) -> str:
return str(self.lhs) + self.op + str(self.rhs)
return str(self.lhs) + ' ' + self.op + ' ' + str(self.rhs)

def __repr__(self) -> str:
return repr(self.lhs) + self.op + repr(self.rhs)

def copy(self) -> Formula:
return DatalogFormula(self.lhs.copy(), self.rhs.copy(), self.op)
Expand All @@ -64,10 +67,13 @@ def __init__(self, predication: str, arg: Argument):
self.predication: str = predication
self.arg: Argument = arg

def __repr__(self) -> str:
return self.predication + '(' + (repr(self.arg) if self.arg is not None else '') + ')'

def __str__(self) -> str:
return self.predication + '(' + str(self.arg) + ')'
return self.predication + '(' + (str(self.arg) if self.arg is not None else '') + ')'

def copy(self) -> Formula:
def copy(self) -> DefinitionClause:
return DefinitionClause(self.predication, self.arg)


Expand All @@ -79,20 +85,28 @@ def copy(self) -> Clause:

class Expression(Clause):

def __init__(self, lhs: Clause, rhs: Clause, op: str, nary: Iterable[Clause] = []):
def __init__(self, lhs: Clause, rhs: Clause, op: str, nary: Iterable[Clause] = None):
self.lhs: Clause = lhs
self.rhs: Clause = rhs
self.nary: list[Clause] = list(nary)
self.nary: list[Clause] = list(nary) if nary is not None else []
self.op: str = op

def __repr__(self) -> str:
if len(self.nary) <= 2:
return repr(self.lhs) + repr(self.op) + repr(self.rhs)
else:
return "'" + self.op + "'(" + ','.join(repr(clause) for clause in self.nary) + ')'

def __str__(self) -> str:
if len(self.nary) == 0:
return '(' + str(self.lhs) + ')' + self.op + '(' + str(self.rhs) + ')'
if len(self.nary) <= 2:
return str(self.lhs) + ('' if self.op == ',' else ' ') + str(self.op) + ' ' + str(self.rhs)
else:
return "'" + self.op + "'(" + ','.join(str(clause) for clause in self.nary) + ')'
return "'" + self.op + "'(" + ', '.join(str(clause) for clause in self.nary) + ')'

def copy(self) -> Formula:
return Expression(self.lhs.copy(), self.rhs.copy(), self.op, [c.copy() for c in self.nary])
def copy(self) -> Expression:
lhs = self.lhs.copy() if self.lhs is not None else None
rhs = self.rhs.copy() if self.rhs is not None else None
return Expression(lhs, rhs, self.op, [c.copy() for c in self.nary if c is not None])


class Literal(Clause, ABC):
Expand All @@ -104,26 +118,37 @@ class Negation(Literal):
def __init__(self, predicate: Clause):
self.predicate: Clause = predicate

def __repr__(self) -> str:
return 'not(' + repr(self.predicate) + ')'

def __str__(self) -> str:
return 'neg(' + str(self.predicate) + ')'
return 'not(' + str(self.predicate) + ')'

def copy(self) -> Formula:
def copy(self) -> Negation:
return Negation(self.predicate.copy())


class Predicate(Clause, ABC):
pass

name: str = ''

@property
def _name(self) -> str:
return self.name if self.name[0].islower() else "'" + self.name + "'"


class Unary(Predicate):

def __init__(self, name: str):
self.name: str = name

def __repr__(self) -> str:
return repr(self.name) + "()"

def __str__(self) -> str:
return self.name
return self._name + "()"

def copy(self) -> Formula:
def copy(self) -> Unary:
return Unary(self.name)


Expand All @@ -133,10 +158,13 @@ def __init__(self, name: str, m: Number, arg: ComplexArgument):
self.m: Number = m
self.arg: ComplexArgument = arg

def __repr__(self) -> str:
return repr(self.name) + '(' + repr(self.m) + ',' + repr(self.arg) + ')'

def __str__(self) -> str:
return self.name + '(' + str(self.m) + ', ' + str(self.arg) + ')'
return self._name + '(' + str(self.m) + ', ' + str(self.arg) + ')'

def copy(self) -> Formula:
def copy(self) -> MofN:
return MofN(self.name, self.m, self.arg)


Expand All @@ -146,10 +174,13 @@ def __init__(self, name: str, arg: Argument):
self.name: str = name
self.arg: Argument = arg

def __repr__(self):
return repr(self.name) + '(' + (repr(self.arg) if self.arg is not None else '') + ')'

def __str__(self) -> str:
return self.name + '(' + str(self.arg) + ')'
return self._name + '(' + (str(self.arg) if self.arg is not None else '') + ')'

def copy(self) -> Formula:
def copy(self) -> Nary:
return Nary(self.name, self.arg)


Expand All @@ -166,10 +197,13 @@ class Predication(Constant):
def __init__(self, name: str):
self.name: str = name

def __repr__(self) -> str:
return repr(self.name)

def __str__(self) -> str:
return self.name
return self._name

def copy(self) -> Formula:
def copy(self) -> Predication:
return Predication(self.name)


Expand All @@ -178,14 +212,17 @@ class Boolean(Constant):
def __init__(self, value: bool = True):
self.value: bool = value

def __repr__(self) -> str:
return repr(self.value)

def __str__(self) -> str:
return str(self.value)

@property
def is_true(self) -> bool:
return self.value

def copy(self) -> Formula:
def copy(self) -> Boolean:
return Boolean(self.value)


Expand All @@ -194,22 +231,28 @@ class Number(Constant):
def __init__(self, value: str):
self.value: float = float(value)

def __repr__(self) -> str:
return repr(self.value)

def __str__(self) -> str:
return str(self.value)

def copy(self) -> Formula:
return Number(self.value)
def copy(self) -> Number:
return Number(str(self.value))


class Variable(Term):

def __init__(self, name: str):
self.name: str = name

def __repr__(self) -> str:
return self.name

def __str__(self) -> str:
return self.name

def copy(self) -> Formula:
def copy(self) -> Variable:
return Variable(self.name)


Expand All @@ -219,10 +262,13 @@ def __init__(self, term: Term, arg: Argument = None):
self.term: Term = term
self.arg: Argument = arg

def __repr__(self) -> str:
return str(self.term) + (',' + str(self.arg) if self.arg is not None else '')

def __str__(self) -> str:
return str(self.term) + (',' + str(self.arg) if self.arg is not None else '')

def copy(self) -> Formula:
def copy(self) -> Argument:
return Argument(self.term.copy(), self.arg)

@property
Expand All @@ -243,10 +289,13 @@ def __init__(self, clause: Clause, arg: ComplexArgument = None):
self.clause: Clause = clause
self.arg: ComplexArgument = arg

def __repr__(self) -> str:
return str(self.clause) + (',' + str(self.arg) if self.arg is not None else '')

def __str__(self) -> str:
return str(self.clause) + (',' + str(self.arg) if self.arg is not None else '')

def copy(self) -> Formula:
def copy(self) -> ComplexArgument:
return ComplexArgument(self.clause.copy(), self.arg)

@property
Expand Down
6 changes: 4 additions & 2 deletions psyki/logic/datalog/grammar/adapters/antlr4.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from os.path import isdir
from psyki.logic.datalog.grammar import DatalogFormula, Expression, DefinitionClause, Argument, Negation, Unary, Nary, \
Variable, Number, Predication, MofN, ComplexArgument
Variable, Number, Predication, MofN, ComplexArgument, optimize_datalog_formula
from psyki.resources import PATH, create_antlr4_parser
if not isdir(str(PATH / 'dist')):
create_antlr4_parser(str(PATH / 'Datalog.g4'), str(PATH / 'dist'))
Expand All @@ -14,7 +14,9 @@ def get_formula_from_string(rule: str) -> DatalogFormula:


def get_formula(ast: DatalogParser.FormulaContext) -> DatalogFormula:
return DatalogFormula(_get_definition_clause(ast.lhs), _get_clause(ast.rhs), ast.op.text)
formula: DatalogFormula = DatalogFormula(_get_definition_clause(ast.lhs), _get_clause(ast.rhs), ast.op.text)
optimize_datalog_formula(formula)
return formula


def to_prolog_string(rule: DatalogFormula) -> str:
Expand Down
4 changes: 2 additions & 2 deletions test/psyki/logic/datalog/test_formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@


class TestFormula(unittest.TestCase):
virginica_rule = "class(PL,PW,SL,SW,virginica) <- PL > 2.28 , PW > 1.64"
expected_formula_structure = "class(PL,PW,SL,SW,virginica)<-((PL)>(2.28)),((PW)>(1.64))"
virginica_rule = "class(PL,PW,SL,SW,virginica) <- PL > 2.28, PW > 1.64"
expected_formula_structure = "class(PL,PW,SL,SW,virginica) <- PL > 2.28, PW > 1.64"

def test_parsing_with_antlr4(self):
formula = get_formula_from_string(self.virginica_rule)
Expand Down
6 changes: 3 additions & 3 deletions test/psyki/logic/tuppy/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class TestConversion(unittest.TestCase):

def test_prolog_from_text_to_datalog(self):
common_head = 'iris(PetalLength,PetalWidth,SepalLength,SepalWidth'
expected_result = [common_head + ',virginica)<-neg(((PetalWidth)>=(0.664341)),((PetalWidth)<(1.651423)))',
common_head + ',setosa)<-(PetalWidth)=<(1.651423)',
common_head + ',versicolor)<-True']
expected_result = [common_head + ',virginica) <- not(PetalWidth >= 0.664341, PetalWidth < 1.651423)',
common_head + ',setosa) <- PetalWidth =< 1.651423',
common_head + ',versicolor) <- True']
prolog_theory = file_to_prolog(PATH / self.iris_kb)
datalog_formulae = prolog_to_datalog(prolog_theory)
for i, formula in enumerate(datalog_formulae):
Expand Down

0 comments on commit 003e614

Please sign in to comment.