From 003e61401f48c7659ec27be50ec18c774f33ec04 Mon Sep 17 00:00:00 2001 From: Matteo Magnini Date: Fri, 18 Nov 2022 12:35:37 +0100 Subject: [PATCH] fix: formulae have repr method correctly implemented, minor bugs detected and solved. This closes #25 --- .../fuzzifiers/lukasciewicz/__init__.py | 16 ++- psyki/logic/datalog/grammar/__init__.py | 105 +++++++++++++----- .../logic/datalog/grammar/adapters/antlr4.py | 6 +- test/psyki/logic/datalog/test_formula.py | 4 +- test/psyki/logic/tuppy/test_conversion.py | 6 +- 5 files changed, 99 insertions(+), 38 deletions(-) diff --git a/psyki/logic/datalog/fuzzifiers/lukasciewicz/__init__.py b/psyki/logic/datalog/fuzzifiers/lukasciewicz/__init__.py index bd1be68..de30dbb 100644 --- a/psyki/logic/datalog/fuzzifiers/lukasciewicz/__init__.py +++ b/psyki/logic/datalog/fuzzifiers/lukasciewicz/__init__.py @@ -1,8 +1,10 @@ from __future__ import annotations from typing import Callable from tensorflow.keras import Model +from tensorflow.python.framework.ops import convert_to_tensor + from psyki.logic.datalog.grammar import * -from tensorflow import cast, SparseTensor, maximum, minimum, constant, reshape, reduce_max, tile +from tensorflow import cast, SparseTensor, maximum, minimum, constant, reshape, reduce_max, tile, reduce_min from tensorflow.python.keras.backend import to_dense from tensorflow.python.ops.array_ops import shape from psyki.logic import Formula, get_logic_symbols_with_short_name @@ -54,6 +56,10 @@ def __init__(self, class_mapping: dict[str, int], feature_mapping: dict[str, int '+': lambda l, r: lambda x: l(x) + r(x), '*': lambda l, r: lambda x: l(x) * r(x) } + self._aggregate_operation = { + _logic_symbols('cj'): lambda args: lambda x: eta(reduce_max(convert_to_tensor([arg(x) for arg in args]), axis=0)), + _logic_symbols('dj'): lambda args: lambda x: eta(reduce_min(convert_to_tensor([arg(x) for arg in args]), axis=0)), + } self._implication = '' @staticmethod @@ -104,8 +110,12 @@ def _visit_definition_clause(self, node: DefinitionClause, rhs: Clause, self._visit(rhs, m)(x)))) def _visit_expression(self, node: Expression, local_mapping: dict[str, int] = None) -> Callable: - l, r = self._visit(node.lhs, local_mapping), self._visit(node.rhs, local_mapping) - return self._operation.get(node.op)(l, r) + if len(node.nary) <= 2: + l, r = self._visit(node.lhs, local_mapping), self._visit(node.rhs, local_mapping) + return self._operation.get(node.op)(l, r) + else: + previous_layer = [self._visit(clause, local_mapping) for clause in node.nary] + return self._aggregate_operation.get(node.op)(previous_layer) def _visit_variable(self, node: Variable, local_mapping: dict[str, int] = None): if node.name in self.feature_mapping.keys(): diff --git a/psyki/logic/datalog/grammar/__init__.py b/psyki/logic/datalog/grammar/__init__.py index 8d19b6e..afcf508 100644 --- a/psyki/logic/datalog/grammar/__init__.py +++ b/psyki/logic/datalog/grammar/__init__.py @@ -6,7 +6,7 @@ # TODO: refactoring -def optimize_datalog_formula(formula: Formula): +def optimize_datalog_formula(formula: Formula) -> None: if isinstance(formula, datalog.grammar.Expression): lhs = formula.lhs rhs = formula.rhs @@ -52,7 +52,10 @@ def __init__(self, lhs: DefinitionClause, rhs: Clause, op: str = '<-'): self.op: str = op def __str__(self) -> str: - return str(self.lhs) + self.op + str(self.rhs) + return str(self.lhs) + ' ' + self.op + ' ' + str(self.rhs) + + def __repr__(self) -> str: + return repr(self.lhs) + self.op + repr(self.rhs) def copy(self) -> Formula: return DatalogFormula(self.lhs.copy(), self.rhs.copy(), self.op) @@ -64,10 +67,13 @@ def __init__(self, predication: str, arg: Argument): self.predication: str = predication self.arg: Argument = arg + def __repr__(self) -> str: + return self.predication + '(' + (repr(self.arg) if self.arg is not None else '') + ')' + def __str__(self) -> str: - return self.predication + '(' + str(self.arg) + ')' + return self.predication + '(' + (str(self.arg) if self.arg is not None else '') + ')' - def copy(self) -> Formula: + def copy(self) -> DefinitionClause: return DefinitionClause(self.predication, self.arg) @@ -79,20 +85,28 @@ def copy(self) -> Clause: class Expression(Clause): - def __init__(self, lhs: Clause, rhs: Clause, op: str, nary: Iterable[Clause] = []): + def __init__(self, lhs: Clause, rhs: Clause, op: str, nary: Iterable[Clause] = None): self.lhs: Clause = lhs self.rhs: Clause = rhs - self.nary: list[Clause] = list(nary) + self.nary: list[Clause] = list(nary) if nary is not None else [] self.op: str = op + def __repr__(self) -> str: + if len(self.nary) <= 2: + return repr(self.lhs) + repr(self.op) + repr(self.rhs) + else: + return "'" + self.op + "'(" + ','.join(repr(clause) for clause in self.nary) + ')' + def __str__(self) -> str: - if len(self.nary) == 0: - return '(' + str(self.lhs) + ')' + self.op + '(' + str(self.rhs) + ')' + if len(self.nary) <= 2: + return str(self.lhs) + ('' if self.op == ',' else ' ') + str(self.op) + ' ' + str(self.rhs) else: - return "'" + self.op + "'(" + ','.join(str(clause) for clause in self.nary) + ')' + return "'" + self.op + "'(" + ', '.join(str(clause) for clause in self.nary) + ')' - def copy(self) -> Formula: - return Expression(self.lhs.copy(), self.rhs.copy(), self.op, [c.copy() for c in self.nary]) + def copy(self) -> Expression: + lhs = self.lhs.copy() if self.lhs is not None else None + rhs = self.rhs.copy() if self.rhs is not None else None + return Expression(lhs, rhs, self.op, [c.copy() for c in self.nary if c is not None]) class Literal(Clause, ABC): @@ -104,15 +118,23 @@ class Negation(Literal): def __init__(self, predicate: Clause): self.predicate: Clause = predicate + def __repr__(self) -> str: + return 'not(' + repr(self.predicate) + ')' + def __str__(self) -> str: - return 'neg(' + str(self.predicate) + ')' + return 'not(' + str(self.predicate) + ')' - def copy(self) -> Formula: + def copy(self) -> Negation: return Negation(self.predicate.copy()) class Predicate(Clause, ABC): - pass + + name: str = '' + + @property + def _name(self) -> str: + return self.name if self.name[0].islower() else "'" + self.name + "'" class Unary(Predicate): @@ -120,10 +142,13 @@ class Unary(Predicate): def __init__(self, name: str): self.name: str = name + def __repr__(self) -> str: + return repr(self.name) + "()" + def __str__(self) -> str: - return self.name + return self._name + "()" - def copy(self) -> Formula: + def copy(self) -> Unary: return Unary(self.name) @@ -133,10 +158,13 @@ def __init__(self, name: str, m: Number, arg: ComplexArgument): self.m: Number = m self.arg: ComplexArgument = arg + def __repr__(self) -> str: + return repr(self.name) + '(' + repr(self.m) + ',' + repr(self.arg) + ')' + def __str__(self) -> str: - return self.name + '(' + str(self.m) + ', ' + str(self.arg) + ')' + return self._name + '(' + str(self.m) + ', ' + str(self.arg) + ')' - def copy(self) -> Formula: + def copy(self) -> MofN: return MofN(self.name, self.m, self.arg) @@ -146,10 +174,13 @@ def __init__(self, name: str, arg: Argument): self.name: str = name self.arg: Argument = arg + def __repr__(self): + return repr(self.name) + '(' + (repr(self.arg) if self.arg is not None else '') + ')' + def __str__(self) -> str: - return self.name + '(' + str(self.arg) + ')' + return self._name + '(' + (str(self.arg) if self.arg is not None else '') + ')' - def copy(self) -> Formula: + def copy(self) -> Nary: return Nary(self.name, self.arg) @@ -166,10 +197,13 @@ class Predication(Constant): def __init__(self, name: str): self.name: str = name + def __repr__(self) -> str: + return repr(self.name) + def __str__(self) -> str: - return self.name + return self._name - def copy(self) -> Formula: + def copy(self) -> Predication: return Predication(self.name) @@ -178,6 +212,9 @@ class Boolean(Constant): def __init__(self, value: bool = True): self.value: bool = value + def __repr__(self) -> str: + return repr(self.value) + def __str__(self) -> str: return str(self.value) @@ -185,7 +222,7 @@ def __str__(self) -> str: def is_true(self) -> bool: return self.value - def copy(self) -> Formula: + def copy(self) -> Boolean: return Boolean(self.value) @@ -194,11 +231,14 @@ class Number(Constant): def __init__(self, value: str): self.value: float = float(value) + def __repr__(self) -> str: + return repr(self.value) + def __str__(self) -> str: return str(self.value) - def copy(self) -> Formula: - return Number(self.value) + def copy(self) -> Number: + return Number(str(self.value)) class Variable(Term): @@ -206,10 +246,13 @@ class Variable(Term): def __init__(self, name: str): self.name: str = name + def __repr__(self) -> str: + return self.name + def __str__(self) -> str: return self.name - def copy(self) -> Formula: + def copy(self) -> Variable: return Variable(self.name) @@ -219,10 +262,13 @@ def __init__(self, term: Term, arg: Argument = None): self.term: Term = term self.arg: Argument = arg + def __repr__(self) -> str: + return str(self.term) + (',' + str(self.arg) if self.arg is not None else '') + def __str__(self) -> str: return str(self.term) + (',' + str(self.arg) if self.arg is not None else '') - def copy(self) -> Formula: + def copy(self) -> Argument: return Argument(self.term.copy(), self.arg) @property @@ -243,10 +289,13 @@ def __init__(self, clause: Clause, arg: ComplexArgument = None): self.clause: Clause = clause self.arg: ComplexArgument = arg + def __repr__(self) -> str: + return str(self.clause) + (',' + str(self.arg) if self.arg is not None else '') + def __str__(self) -> str: return str(self.clause) + (',' + str(self.arg) if self.arg is not None else '') - def copy(self) -> Formula: + def copy(self) -> ComplexArgument: return ComplexArgument(self.clause.copy(), self.arg) @property diff --git a/psyki/logic/datalog/grammar/adapters/antlr4.py b/psyki/logic/datalog/grammar/adapters/antlr4.py index bbfcaf7..997c5be 100644 --- a/psyki/logic/datalog/grammar/adapters/antlr4.py +++ b/psyki/logic/datalog/grammar/adapters/antlr4.py @@ -1,6 +1,6 @@ from os.path import isdir from psyki.logic.datalog.grammar import DatalogFormula, Expression, DefinitionClause, Argument, Negation, Unary, Nary, \ - Variable, Number, Predication, MofN, ComplexArgument + Variable, Number, Predication, MofN, ComplexArgument, optimize_datalog_formula from psyki.resources import PATH, create_antlr4_parser if not isdir(str(PATH / 'dist')): create_antlr4_parser(str(PATH / 'Datalog.g4'), str(PATH / 'dist')) @@ -14,7 +14,9 @@ def get_formula_from_string(rule: str) -> DatalogFormula: def get_formula(ast: DatalogParser.FormulaContext) -> DatalogFormula: - return DatalogFormula(_get_definition_clause(ast.lhs), _get_clause(ast.rhs), ast.op.text) + formula: DatalogFormula = DatalogFormula(_get_definition_clause(ast.lhs), _get_clause(ast.rhs), ast.op.text) + optimize_datalog_formula(formula) + return formula def to_prolog_string(rule: DatalogFormula) -> str: diff --git a/test/psyki/logic/datalog/test_formula.py b/test/psyki/logic/datalog/test_formula.py index adf4e55..6613519 100644 --- a/test/psyki/logic/datalog/test_formula.py +++ b/test/psyki/logic/datalog/test_formula.py @@ -3,8 +3,8 @@ class TestFormula(unittest.TestCase): - virginica_rule = "class(PL,PW,SL,SW,virginica) <- PL > 2.28 , PW > 1.64" - expected_formula_structure = "class(PL,PW,SL,SW,virginica)<-((PL)>(2.28)),((PW)>(1.64))" + virginica_rule = "class(PL,PW,SL,SW,virginica) <- PL > 2.28, PW > 1.64" + expected_formula_structure = "class(PL,PW,SL,SW,virginica) <- PL > 2.28, PW > 1.64" def test_parsing_with_antlr4(self): formula = get_formula_from_string(self.virginica_rule) diff --git a/test/psyki/logic/tuppy/test_conversion.py b/test/psyki/logic/tuppy/test_conversion.py index c3297b0..a417de1 100644 --- a/test/psyki/logic/tuppy/test_conversion.py +++ b/test/psyki/logic/tuppy/test_conversion.py @@ -10,9 +10,9 @@ class TestConversion(unittest.TestCase): def test_prolog_from_text_to_datalog(self): common_head = 'iris(PetalLength,PetalWidth,SepalLength,SepalWidth' - expected_result = [common_head + ',virginica)<-neg(((PetalWidth)>=(0.664341)),((PetalWidth)<(1.651423)))', - common_head + ',setosa)<-(PetalWidth)=<(1.651423)', - common_head + ',versicolor)<-True'] + expected_result = [common_head + ',virginica) <- not(PetalWidth >= 0.664341, PetalWidth < 1.651423)', + common_head + ',setosa) <- PetalWidth =< 1.651423', + common_head + ',versicolor) <- True'] prolog_theory = file_to_prolog(PATH / self.iris_kb) datalog_formulae = prolog_to_datalog(prolog_theory) for i, formula in enumerate(datalog_formulae):