From 8ca6a047217977e18a564c385e3ca4bb5e403d4c Mon Sep 17 00:00:00 2001 From: Amade Nemes Date: Sat, 15 Jul 2023 12:34:01 +0200 Subject: [PATCH] Finish refactoring and extending test to 100% coverage. --- src/renopro/asp/encodings/transform.lp | 6 +- .../reify/{ => good_ast}/binary_operation.lp | 0 .../reify/{ => good_ast}/constant_term.lp | 0 .../tests/reify/{ => good_ast}/external.lp | 0 .../tests/reify/{ => good_ast}/function.lp | 0 .../reify/{ => good_ast}/nested_function.lp | 0 .../asp/tests/reify/good_ast/program_acid.lp | 7 + .../tests/reify/{ => good_ast}/prop_fact.lp | 0 .../reify/{ => good_ast}/prop_normal_rule.lp | 0 .../asp/tests/reify/{ => good_ast}/string.lp | 0 .../tests/reify/{ => good_ast}/variable.lp | 0 .../asp/tests/reify/malformed_ast/ast_fact.lp | 1 + .../reify/malformed_ast/missing_child.lp | 8 + .../reify/malformed_ast/multiple_child.lp | 11 + .../reify/malformed_ast/not_an_ast_fact.lp | 1 + .../asp/tests/transform/not_bad/transform.lp | 3 +- .../transform/prev_to_timepoints/transform.lp | 23 +- src/renopro/predicates.py | 4 +- src/renopro/reify.py | 346 ++++++++++-------- tests/test_reify_reflect.py | 209 ++++++++--- tests/test_transform.py | 10 + 21 files changed, 396 insertions(+), 233 deletions(-) rename src/renopro/asp/tests/reify/{ => good_ast}/binary_operation.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/constant_term.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/external.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/function.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/nested_function.lp (100%) create mode 100644 src/renopro/asp/tests/reify/good_ast/program_acid.lp rename src/renopro/asp/tests/reify/{ => good_ast}/prop_fact.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/prop_normal_rule.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/string.lp (100%) rename src/renopro/asp/tests/reify/{ => good_ast}/variable.lp (100%) create mode 100644 src/renopro/asp/tests/reify/malformed_ast/ast_fact.lp create mode 100644 src/renopro/asp/tests/reify/malformed_ast/missing_child.lp create mode 100644 src/renopro/asp/tests/reify/malformed_ast/multiple_child.lp create mode 100644 src/renopro/asp/tests/reify/malformed_ast/not_an_ast_fact.lp diff --git a/src/renopro/asp/encodings/transform.lp b/src/renopro/asp/encodings/transform.lp index 0852e3c..0714013 100644 --- a/src/renopro/asp/encodings/transform.lp +++ b/src/renopro/asp/encodings/transform.lp @@ -15,8 +15,8 @@ ast(constant_tuple(Id,Pos,Element)) :- constant_tuple(Id,Pos,Element). ast(program(Name,Params,Statements)) :- program(Name,Params,Statements). -final(A) :- ast(A), not delete(A), not replace(A,_). -final(R) :- replace(_,R). -final(N) :- add(N). +final(A) :- ast(A), not ast_operation(delete(A)), not ast_operation(replace(A,_)). +final(R) :- ast_operation(replace(_,R)). +final(N) :- ast_operation(add(N)). #show final/1. diff --git a/src/renopro/asp/tests/reify/binary_operation.lp b/src/renopro/asp/tests/reify/good_ast/binary_operation.lp similarity index 100% rename from src/renopro/asp/tests/reify/binary_operation.lp rename to src/renopro/asp/tests/reify/good_ast/binary_operation.lp diff --git a/src/renopro/asp/tests/reify/constant_term.lp b/src/renopro/asp/tests/reify/good_ast/constant_term.lp similarity index 100% rename from src/renopro/asp/tests/reify/constant_term.lp rename to src/renopro/asp/tests/reify/good_ast/constant_term.lp diff --git a/src/renopro/asp/tests/reify/external.lp b/src/renopro/asp/tests/reify/good_ast/external.lp similarity index 100% rename from src/renopro/asp/tests/reify/external.lp rename to src/renopro/asp/tests/reify/good_ast/external.lp diff --git a/src/renopro/asp/tests/reify/function.lp b/src/renopro/asp/tests/reify/good_ast/function.lp similarity index 100% rename from src/renopro/asp/tests/reify/function.lp rename to src/renopro/asp/tests/reify/good_ast/function.lp diff --git a/src/renopro/asp/tests/reify/nested_function.lp b/src/renopro/asp/tests/reify/good_ast/nested_function.lp similarity index 100% rename from src/renopro/asp/tests/reify/nested_function.lp rename to src/renopro/asp/tests/reify/good_ast/nested_function.lp diff --git a/src/renopro/asp/tests/reify/good_ast/program_acid.lp b/src/renopro/asp/tests/reify/good_ast/program_acid.lp new file mode 100644 index 0000000..7c1283e --- /dev/null +++ b/src/renopro/asp/tests/reify/good_ast/program_acid.lp @@ -0,0 +1,7 @@ +% reified fact representation of program: +% #program acid(k). + +program("base",constant_tuple(0),statement_tuple(1)). +program("acid",constant_tuple(2),statement_tuple(5)). +constant_tuple(2,0,function(3)). +function(3,k,term_tuple(4)). diff --git a/src/renopro/asp/tests/reify/prop_fact.lp b/src/renopro/asp/tests/reify/good_ast/prop_fact.lp similarity index 100% rename from src/renopro/asp/tests/reify/prop_fact.lp rename to src/renopro/asp/tests/reify/good_ast/prop_fact.lp diff --git a/src/renopro/asp/tests/reify/prop_normal_rule.lp b/src/renopro/asp/tests/reify/good_ast/prop_normal_rule.lp similarity index 100% rename from src/renopro/asp/tests/reify/prop_normal_rule.lp rename to src/renopro/asp/tests/reify/good_ast/prop_normal_rule.lp diff --git a/src/renopro/asp/tests/reify/string.lp b/src/renopro/asp/tests/reify/good_ast/string.lp similarity index 100% rename from src/renopro/asp/tests/reify/string.lp rename to src/renopro/asp/tests/reify/good_ast/string.lp diff --git a/src/renopro/asp/tests/reify/variable.lp b/src/renopro/asp/tests/reify/good_ast/variable.lp similarity index 100% rename from src/renopro/asp/tests/reify/variable.lp rename to src/renopro/asp/tests/reify/good_ast/variable.lp diff --git a/src/renopro/asp/tests/reify/malformed_ast/ast_fact.lp b/src/renopro/asp/tests/reify/malformed_ast/ast_fact.lp new file mode 100644 index 0000000..621b233 --- /dev/null +++ b/src/renopro/asp/tests/reify/malformed_ast/ast_fact.lp @@ -0,0 +1 @@ +atom(12,function(13)). diff --git a/src/renopro/asp/tests/reify/malformed_ast/missing_child.lp b/src/renopro/asp/tests/reify/malformed_ast/missing_child.lp new file mode 100644 index 0000000..e923ef0 --- /dev/null +++ b/src/renopro/asp/tests/reify/malformed_ast/missing_child.lp @@ -0,0 +1,8 @@ +% Malformed set of ast facts with a missing child fact + +program("base",constant_tuple(0),statement_tuple(1)). +statement_tuple(1,0,rule(2)). +rule(2,literal(3),literal_tuple(7)). +% facts representing head literal a. +literal(3,"pos",atom(4)). +atom(4,function(5)). % child fact function(5,_,_) is missing but required diff --git a/src/renopro/asp/tests/reify/malformed_ast/multiple_child.lp b/src/renopro/asp/tests/reify/malformed_ast/multiple_child.lp new file mode 100644 index 0000000..7238b35 --- /dev/null +++ b/src/renopro/asp/tests/reify/malformed_ast/multiple_child.lp @@ -0,0 +1,11 @@ +% Malformed set of ast facts with a missing child fact + +program("base",constant_tuple(0),statement_tuple(1)). +statement_tuple(1,0,rule(2)). +rule(2,literal(3),literal_tuple(7)). +% facts representing head literal a. +literal(3,"pos",atom(4)). +atom(4,function(5)). +% should only have one child like function(5,_,_) +function(5,a,term_tuple(6)). +function(5,b,term_tuple(7)). diff --git a/src/renopro/asp/tests/reify/malformed_ast/not_an_ast_fact.lp b/src/renopro/asp/tests/reify/malformed_ast/not_an_ast_fact.lp new file mode 100644 index 0000000..e7b299d --- /dev/null +++ b/src/renopro/asp/tests/reify/malformed_ast/not_an_ast_fact.lp @@ -0,0 +1 @@ +notafact(dog, "mobius"). diff --git a/src/renopro/asp/tests/transform/not_bad/transform.lp b/src/renopro/asp/tests/transform/not_bad/transform.lp index 9e6b8ba..2f4c53d 100644 --- a/src/renopro/asp/tests/transform/not_bad/transform.lp +++ b/src/renopro/asp/tests/transform/not_bad/transform.lp @@ -19,10 +19,11 @@ max_lit_index(LT,Idx) % the transformation itself +ast_operation( add((literal_tuple(LT,N+1,literal(new_id(LT,0))); literal(new_id(LT,0),"not",atom(new_id(LT,1))); atom(new_id(LT,1),function(new_id(LT,2))); - function(new_id(LT,2),bad,Fargs))) + function(new_id(LT,2),bad,Fargs)))) :- rule(_,literal(L),literal_tuple(LT)), literal(L,"pos",atom(A)), atom(A,function(F)), diff --git a/src/renopro/asp/tests/transform/prev_to_timepoints/transform.lp b/src/renopro/asp/tests/transform/prev_to_timepoints/transform.lp index c5766d4..fda59a3 100644 --- a/src/renopro/asp/tests/transform/prev_to_timepoints/transform.lp +++ b/src/renopro/asp/tests/transform/prev_to_timepoints/transform.lp @@ -25,8 +25,9 @@ first_prev(A,F) :- prev_chain(A,F,O), not prev_chain(A,_,F). % when an atom is not prev. % add time point constant as additional argument -add(term_tuple(T,N+1,function(new_id(T))); - function(new_id(T),t,term_tuple(new_id(T)))) +ast_operation( + add(term_tuple(T,N+1,function(new_id(T))); + function(new_id(T),t,term_tuple(new_id(T))))) :- atom(A,function(F)), function(F,Name,term_tuple(T)), Name!=prev, max_arg_index(T,N). @@ -34,14 +35,12 @@ add(term_tuple(T,N+1,function(new_id(T))); % replace function symbol of atom with final operand, appending % appropriate time point as additional argument. -replace(function(F,N,T1),function(F,Name,term_tuple(T2))) +ast_operation( + delete(function(F,N,T1)); + add(function(F,Name,term_tuple(T2)); + term_tuple(T2,I+1,binary_operation(new_id(O))); + binary_operation(new_id(O),"-",function(new_id(O)),number(new_id(O))); + function(new_id(O),t,term_tuple(new_id(O))); + number(new_id(O),Num))) :- first_prev(A,function(F)), function(F,N,T1), final_operand(A,function(O)), - function(O,Name,term_tuple(T2)). - -add((term_tuple(T,I+1,binary_operation(new_id(O))); - binary_operation(new_id(O),"-",function(new_id(O)),number(new_id(O))); - function(new_id(O),t,term_tuple(new_id(O))); - number(new_id(O),N))) -:- final_operand(A,function(O)), function(O,_,term_tuple(T)), max_arg_index(T,I), - num_prevs(A,N). - + function(O,Name,term_tuple(T2)), max_arg_index(T2,I), num_prevs(A,Num). diff --git a/src/renopro/predicates.py b/src/renopro/predicates.py index 857840a..d64a63d 100644 --- a/src/renopro/predicates.py +++ b/src/renopro/predicates.py @@ -12,14 +12,12 @@ Predicate, RawField, StringField, - Unifier, combine_fields, define_enum_field, refine_field, ) id_count = count() -next_id = lambda: next(id_count) # by default we use integer identifiers, but allow arbitrary symbols as well # for flexibility when these are user generated @@ -38,7 +36,7 @@ def make_id_predicate(ast_pred): id_pred_name, (Predicate,), { - "id": Identifier_Field(default=next_id), + "id": Identifier_Field(default=lambda: next(id_count)), "Meta": type("Meta", tuple(), {"name": ast_pred.meta.name}), }, ) diff --git a/src/renopro/reify.py b/src/renopro/reify.py index 146eb05..ab37003 100644 --- a/src/renopro/reify.py +++ b/src/renopro/reify.py @@ -4,7 +4,9 @@ import sys from functools import singledispatchmethod from pathlib import Path -from typing import Iterable, Optional, Sequence, Union +from typing import Iterable, Optional, Sequence, Union, List +import inspect +from contextlib import contextmanager, AbstractContextManager from clingo import Control, ast, symbol from clingo.ast import AST, ASTType, Location, Position, parse_files, parse_string @@ -22,123 +24,142 @@ import renopro.predicates as preds +logger = logging.getLogger(__name__) + +DUMMY_LOC = Location(Position("", 1, 1), Position("", 1, 1)) + + class ChildQueryError(Exception): - pass + """Exception raised when a required child fact of an AST fact + cannot be found. + + """ class ChildrenQueryError(Exception): - pass - - -def raise_unmatched_ast_fact(error: UnifierNoMatchError): - """Enhance UnifierNoMatchError with some more useful error - messages to help debug the reason unification failed.""" - unmatched = error.symbol - name2arity2pred = { - pred.meta.name: {pred.meta.arity: pred} for pred in preds.AST_Facts - } - candidate = name2arity2pred.get(unmatched.name, dict()).get( - len(unmatched.arguments) - ) - if candidate is None: - fuzzy_name = process.extractOne(unmatched.name, name2arity2pred.keys())[0] - signatures = [f"{fuzzy_name}/{arity}." for arity in name2arity2pred[fuzzy_name]] - raise UnifierNoMatchError( - ( - "No AST fact of matching signature found for symbol\n" - f"'{unmatched}'.\nSimilar AST fact signatures are:\n" - + "\n".join(signatures) - ), - unmatched, - error.predicates, - ) from None - else: + """Exception raised when the expected number child facts of an AST + fact cannot be found. + + """ + + +class TryUnify(AbstractContextManager): + """Context manager to try some operation that requires unification + of some set of ast facts. Enhance error message if unification fails""" + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is UnifierNoMatchError: + self.handle_unify_error(exc_value) + + @staticmethod + def handle_unify_error(error): + """Enhance UnifierNoMatchError with some more + useful error messages to help debug the reason unification failed. + + """ + unmatched = error.symbol + name2arity2pred = { + pred.meta.name: {pred.meta.arity: pred} for pred in preds.AST_Facts + } + candidate = name2arity2pred.get(unmatched.name, {}).get(len(unmatched.arguments)) + if candidate is None: + fuzzy_name = process.extractOne(unmatched.name, name2arity2pred.keys())[0] + signatures = [f"{fuzzy_name}/{arity}." for arity in name2arity2pred[fuzzy_name]] + msg = f"""No AST fact of matching signature found for symbol + '{unmatched}'. + Similar AST fact signatures are: + """ + "\n".join(signatures) + raise UnifierNoMatchError( + inspect.cleandoc(msg), + unmatched, + error.predicates, + ) from None for idx, arg in enumerate(unmatched.arguments): + # This is very hacky. Should ask Dave for a better solution, if there is one. arg_field = candidate[idx]._field - arg_field_str = re.sub("\(.*?\)", "", str(arg_field)) + arg_field_str = re.sub(r"\(.*?\)", "", str(arg_field)) try: arg_field.cltopy(arg) except (TypeError, ValueError): - msg = ( - f"\nCannot unify symbol\n'{unmatched}'\nto only " - "candidate AST fact of matching signature " - f"{candidate.meta.name}/{candidate.meta.arity}\n" - f"due to failure to unify symbol's argument '{arg}' " - f"against the corresponding field '{arg_field_str}'." - ) + msg = f"""Cannot unify symbol + '{unmatched}' + to only candidate AST fact of matching signature + {candidate.meta.name}/{candidate.meta.arity} + due to failure to unify symbol's argument + '{arg}' + against the corresponding field + '{arg_field_str}'.""" raise UnifierNoMatchError( - msg, + inspect.cleandoc(msg), unmatched, (candidate,), ) from None raise RuntimeError("Code should be unreachable") # nocoverage -logger = logging.getLogger(__name__) - -DUMMY_LOC = Location(Position("", 1, 1), Position("", 1, 1)) - - class ReifiedAST: """Class for converting between reified and non-reified representation of ASP programs.""" def __init__(self): self._reified = FactBase() - self._program_statements = list() + self._program_ast = [] self._id_counter = -1 def add_reified_facts(self, reified_facts: Iterable[preds.AST_Predicate]) -> None: - """Add factbase containing reified facts into internal factbase.""" - self._reified.update(reified_facts) + """Add factbase containing reified AST facts into internal factbase.""" + unifier = Unifier(preds.AST_Facts) + # couldn't find a way in clorm to directly add a set of facts + # while checking unification, so we have to unify against the + # underlying symbols + with TryUnify(): + reified_facts = unifier.iter_unify( + [fact.symbol for fact in reified_facts], raise_nomatch=True + ) + self._reified.update(reified_facts) def add_reified_string(self, reified_string: str) -> None: """Add string of reified facts into internal factbase.""" unifier = preds.AST_Facts - try: + with TryUnify(): facts = parse_fact_string( reified_string, unifier=unifier, raise_nomatch=True, raise_nonfact=True ) - except UnifierNoMatchError as e: - raise_unmatched_ast_fact(e) self._reified.update(facts) def add_reified_files(self, reified_files: Sequence[Path]) -> None: """Add files containing reified facts into internal factbase.""" reified_files = [str(f) for f in reified_files] - try: + with TryUnify(): facts = parse_fact_files( reified_files, unifier=preds.AST_Facts, raise_nomatch=True, raise_nonfact=True, ) - except UnifierNoMatchError as e: - raise_unmatched_ast_fact(e) self._reified.update(facts) def reify_string(self, prog_str: str) -> None: """Reify input program string, adding reified facts to the internal factbase.""" - parse_string(prog_str, self._reify_ast) + parse_string(prog_str, self.reify_node) def reify_files(self, files: Sequence[Path]) -> None: """Reify input program files, adding reified facts to the internal factbase.""" for f in files: - if not f.is_file(): + if not f.is_file(): # nocoverage raise IOError(f"File {f} does not exist.") files = [str(f) for f in files] - self._program_statements = list() - parse_files(files, self._reify_ast) + parse_files(files, self.reify_node) @property def program_string(self) -> str: - return "\n".join([str(statement) for statement in self._program_statements]) + return "\n".join([str(statement) for statement in self._program_ast]) @property - def program_ast(self) -> str: - return self._program_statements + def program_ast(self) -> List[AST]: + return self._program_ast @property def reified_facts(self) -> FactBase: @@ -149,7 +170,7 @@ def reified_string(self) -> str: return self._reified.asp_str() @property - def reified_string_doc(self) -> str: + def reified_string_doc(self) -> str: # nocoverage return self._reified.asp_str(commented=True) @staticmethod @@ -178,10 +199,8 @@ def wrapper(self, node, *args, **kw): return dispatch(node.ast_type)(self, node, *args, **kw) elif isinstance(node, Symbol): return dispatch(node.type)(self, node, *args, **kw) - else: # nocoverage - raise RuntimeError( - ("Nodes should be of type AST or Symbol, " f"got: {type(node)}") - ) + else: + return dispatch(type(node))(self, node, *args, **kw) wrapper.register = register wrapper.dispatch = dispatch @@ -190,58 +209,61 @@ def wrapper(self, node, *args, **kw): return wrapper @dispatch_on_node_type - def _reify_ast(self, node): + def reify_node(self, node): """Reify the input ast node by adding it's clorm fact - representation to the internal fact base. + representation to the internal fact base, and recursively + reify child nodes. """ - if hasattr(node, "ast_type"): + if hasattr(node, "ast_type"): # nocoverage raise NotImplementedError( ( "Reification not implemented for AST nodes of type: " "f{node.ast_type.name}." ) ) - elif hasattr(node, "type"): + elif hasattr(node, "type"): # nocoverage raise NotImplementedError( ( "Reification not implemented for symbol of type: " "f{node.typle.name}." ) ) - else: # nocoverage - raise RuntimeError("Code block should be unreachable.") + else: + raise TypeError(f"Nodes should be of type AST or Symbol, got: {type(node)}") - @_reify_ast.register(ASTType.Program) + @reify_node.register(ASTType.Program) def _reify_program(self, node): - self._program_statements.append(node) + self._program_ast.append(node) + const_tup_id = preds.Constant_Tuple1() + for pos, param in enumerate(node.parameters): + const_tup = preds.Constant_Tuple( + id=const_tup_id.id, + position=pos, + element=preds.Function1() + ) + const = preds.Function(id=const_tup.element.id, + name=param.name, + arguments=preds.Term_Tuple1()) + self._reified.add([const_tup, const]) program = preds.Program( name=node.name, - parameters=preds.Constant_Tuple1(), + parameters=const_tup_id, statements=preds.Statement_Tuple1(), ) self._reified.add(program) self._statement_tup_id = program.statements.id self._statement_pos = 0 - for pos, param in enumerate(node.parameters): - self._reified.add( - preds.Constant_Tuple( - id=program.parameters.id, - position=pos, - # note: this id refers to the clingo.ast.Id.id attribute - element=param.id, - ) - ) return - @_reify_ast.register(ASTType.External) + @reify_node.register(ASTType.External) def _reify_external(self, node): - self._program_statements.append(node) + self._program_ast.append(node) ext_type = node.external_type.symbol.name external1 = preds.External1() external = preds.External( id=external1.id, - atom=self._reify_ast(node.atom), + atom=self.reify_node(node.atom), body=preds.Literal_Tuple1(), external_type=ext_type, ) @@ -254,18 +276,18 @@ def _reify_external(self, node): for pos, element in enumerate(node.body, start=0): self._reified.add( preds.Literal_Tuple( - id=external.body.id, position=pos, element=self._reify_ast(element) + id=external.body.id, position=pos, element=self.reify_node(element) ) ) return - @_reify_ast.register(ASTType.Rule) + @reify_node.register(ASTType.Rule) def _reify_rule(self, node): - self._program_statements.append(node) + self._program_ast.append(node) rule1 = preds.Rule1() # assumption: head can only be a Literal - head = self._reify_ast(node.head) + head = self.reify_node(node.head) rule = preds.Rule(id=rule1.id, head=head, body=preds.Literal_Tuple1()) self._reified.add(rule) statement_tup = preds.Statement_Tuple( @@ -276,28 +298,28 @@ def _reify_rule(self, node): for pos, element in enumerate(node.body, start=0): self._reified.add( preds.Literal_Tuple( - id=rule.body.id, position=pos, element=self._reify_ast(element) + id=rule.body.id, position=pos, element=self.reify_node(element) ) ) return - @_reify_ast.register(ASTType.Literal) + @reify_node.register(ASTType.Literal) def _reify_literal(self, node): # assumption: all literals contain only symbolic atoms lit1 = preds.Literal1() clorm_sign = preds.sign_ast2cl[node.sign] - lit = preds.Literal(id=lit1.id, sig=clorm_sign, atom=self._reify_ast(node.atom)) + lit = preds.Literal(id=lit1.id, sig=clorm_sign, atom=self.reify_node(node.atom)) self._reified.add(lit) return lit1 - @_reify_ast.register(ASTType.SymbolicAtom) + @reify_node.register(ASTType.SymbolicAtom) def _reify_symbolic_atom(self, node): atom1 = preds.Atom1() - atom = preds.Atom(id=atom1.id, symbol=self._reify_ast(node.symbol)) + atom = preds.Atom(id=atom1.id, symbol=self.reify_node(node.symbol)) self._reified.add(atom) return atom1 - @_reify_ast.register(ASTType.Function) + @reify_node.register(ASTType.Function) def _reify_function(self, node): """Reify an ast node with node.ast_type of ASTType.Function. @@ -319,18 +341,18 @@ def _reify_function(self, node): preds.Term_Tuple( id=function.arguments.id, position=pos, - element=self._reify_ast(term), + element=self.reify_node(term), ) ) return function1 - @_reify_ast.register(ASTType.Variable) + @reify_node.register(ASTType.Variable) def _reify_variable(self, node): variable1 = preds.Variable1() self._reified.add(preds.Variable(id=variable1.id, name=node.name)) return variable1 - @_reify_ast.register(ASTType.SymbolicTerm) + @reify_node.register(ASTType.SymbolicTerm) def _reify_symbolic_term(self, node): """Reify symbolic term. @@ -339,28 +361,28 @@ def _reify_symbolic_term(self, node): don't represent this ast node in our reification. """ - return self._reify_ast(node.symbol) + return self.reify_node(node.symbol) - @_reify_ast.register(ASTType.BinaryOperation) + @reify_node.register(ASTType.BinaryOperation) def _reify_binary_operation(self, node): clorm_operator = preds.binary_operator_ast2cl[node.operator_type] binop1 = preds.Binary_Operation1() binop = preds.Binary_Operation( id=binop1.id, operator=clorm_operator, - left=self._reify_ast(node.left), - right=self._reify_ast(node.right), + left=self.reify_node(node.left), + right=self.reify_node(node.right), ) self._reified.add(binop) return binop1 - @_reify_ast.register(SymbolType.Number) + @reify_node.register(SymbolType.Number) def _reify_symbol_number(self, symb): number1 = preds.Number1() self._reified.add(preds.Number(id=number1.id, value=symb.number)) return number1 - @_reify_ast.register(SymbolType.Function) + @reify_node.register(SymbolType.Function) def _reify_symbol_function(self, symb): """Reify constant term. @@ -374,78 +396,82 @@ def _reify_symbol_function(self, symb): ) return func1 - @_reify_ast.register(SymbolType.String) + @reify_node.register(SymbolType.String) def _reify_symbol_string(self, symb): string1 = preds.String1() self._reified.add(preds.String(id=string1.id, value=symb.string)) return string1 - def _reflect_child_pred(self, parent_pred, child_id_pred): + def _reflect_child_pred(self, parent_fact, child_id_fact): """Utility function that takes a unary ast predicate containing only an identifier pointing to a child predicate, queries reified factbase for child predicate, and returns the child node obtained by reflecting the child predicate. """ - identifier = child_id_pred.id - child_ast_pred = preds.id_pred2ast_pred[type(child_id_pred)] - query = self._reified.query(child_ast_pred).where( - child_ast_pred.id == identifier + identifier = child_id_fact.id + child_ast_pred = preds.id_pred2ast_pred[type(child_id_fact)] + query = (self._reified.query(child_ast_pred) + .where(child_ast_pred.id == identifier) ) child_preds = list(query.all()) if len(child_preds) == 0: msg = ( - f"Error finding child fact of predicate:\n{parent_pred}:\n" - f"Expected single child fact for identifier {child_id_pred}" + f"Error finding child fact of predicate '{parent_fact}':\n" + f"Expected single child fact for identifier '{child_id_fact}'" ", found none." ) raise ChildQueryError(msg) elif len(child_preds) > 1: child_pred_strings = [str(pred) for pred in child_preds] msg = ( - f"Error finding child fact of predicate:\n{parent_pred}:\n" - f"Expected single child fact for identifier {child_id_pred}" + f"Error finding child fact of predicate '{parent_fact}':\n" + f"Expected single child fact for identifier '{child_id_fact}'" ", found multiple:\n" + "\n".join(child_pred_strings) ) raise ChildQueryError(msg) else: child_pred = child_preds[0] - return self._reflect_predicate(child_pred) + return self.reflect_predicate(child_pred) - def _reflect_child_preds(self, parent_pred, id_predicate): - """Utility function that takes a unary ast predicate - containing only an identifier pointing to a tuple of child - predicates, and returns a list of the child nodes obtained by - reflecting all child predicates in order. + def _reflect_child_preds(self, parent_fact, children_id_fact): + """Utility function that takes a unary ast fact containing + only an identifier pointing to a tuple of child facts, and + returns a list of the child nodes obtained by reflecting all + child facts. """ - identifier = id_predicate.id - ast_fact_pred = getattr(preds, type(id_predicate).__name__.rstrip("1")) + identifier = children_id_fact.id + child_pred = preds.id_pred2ast_pred[type(children_id_fact)] query = ( - self._reified.query(ast_fact_pred) - .where(ast_fact_pred.id == identifier) - .order_by(ast_fact_pred.position) + self._reified.query(child_pred) + .where(child_pred.id == identifier) + .order_by(child_pred.position) ) tuples = list(query.all()) - child_nodes = list() + child_nodes = [] for tup in tuples: child_nodes.append(self._reflect_child_pred(tup, tup.element)) return child_nodes @singledispatchmethod - def _reflect_predicate(self, pred: preds.AST_Predicate): # nocoverage - """Convert the input AST element's reified clorm predicate - representation back into a the corresponding memer of clingo's - abstract syntax tree. + def reflect_predicate(self, pred: preds.AST_Predicate): # nocoverage + """Convert the input AST element's reified fact representation + back into a the corresponding member of clingo's abstract + syntax tree, recursively reflecting all child facts. """ raise NotImplementedError( - f"reflection not implemented for predicate of type {type(pred)}." + f"Reflection not implemented for predicate of type {type(pred)}." ) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_program(self, program: preds.Program) -> Sequence[AST]: - subprogram = list() + """Reflect a (sub)program fact into sequence of AST nodes, one + node per each statement in the (sub)program. + + """ + subprogram = [] parameter_nodes = self._reflect_child_preds(program, program.parameters) subprogram.append( ast.Program( @@ -456,8 +482,9 @@ def _reflect_program(self, program: preds.Program) -> Sequence[AST]: subprogram.extend(statement_nodes) return subprogram - @_reflect_predicate.register + @reflect_predicate.register def _reflect_external(self, external: preds.External) -> AST: + """Reflect an External fact into an External node.""" atom_node = self._reflect_child_pred(external, external.atom) body_nodes = self._reflect_child_preds(external, external.body) ext_type = ast.SymbolicTerm( @@ -468,49 +495,64 @@ def _reflect_external(self, external: preds.External) -> AST: location=DUMMY_LOC, atom=atom_node, body=body_nodes, external_type=ext_type ) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_rule(self, rule: preds.Rule) -> AST: + """Reflect a Rule fact into a Rule node.""" head_node = self._reflect_child_pred(rule, rule.head) body_nodes = self._reflect_child_preds(rule, rule.body) return ast.Rule(location=DUMMY_LOC, head=head_node, body=body_nodes) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_literal(self, lit: preds.Literal) -> AST: + """Reflect a Literal fact into a Literal node.""" sign = preds.sign_cl2ast[lit.sig] atom_node = self._reflect_child_pred(lit, lit.atom) return ast.Literal(location=DUMMY_LOC, sign=sign, atom=atom_node) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_atom(self, atom: preds.Atom) -> AST: + """Reflect an Atom fact into an Atom node.""" return ast.SymbolicAtom(symbol=self._reflect_child_pred(atom, atom.symbol)) - @_reflect_predicate.register - def _reflect_function(self, func: preds.Function) -> Union[AST, Symbol]: - """Reflect function, which may represent a propositional - constant, predicate, function symbol, or constant term""" + @reflect_predicate.register + def _reflect_function(self, func: preds.Function) -> AST: + """Reflect a Function fact into a Function node. + + Note that a Function fact is used to represent a propositional + constant, predicate, function symbol, or constant term. All of + these can be validly represented by a Function node in the + clingo AST and so we can return a Function node in each case. + Constant terms are parsed a Symbol by the parser, thus we need + to handle them differently when reifying. + + """ arg_nodes = self._reflect_child_preds(func, func.arguments) return ast.Function( location=DUMMY_LOC, name=func.name, arguments=arg_nodes, external=0 ) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_variable(self, var: preds.Variable) -> AST: + """Reflect a Variable fact into a Variable node.""" return ast.Variable(location=DUMMY_LOC, name=var.name) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_number(self, number: preds.Number) -> AST: + """Reflect a Number fact into a Number node.""" return ast.SymbolicTerm( location=DUMMY_LOC, symbol=symbol.Number(number=number.value) ) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_string(self, string: preds.String) -> Symbol: + """Reflect a String fact into a String node.""" return ast.SymbolicTerm( location=DUMMY_LOC, symbol=symbol.String(string=string.value) ) - @_reflect_predicate.register + @reflect_predicate.register def _reflect_binary_operation(self, operation: preds.Binary_Operation) -> AST: + """Reflect a Binary_Operation fact into a BinaryOperation node.""" ast_operator = preds.binary_operator_cl2ast[operation.operator] return ast.BinaryOperation( location=DUMMY_LOC, @@ -523,11 +565,11 @@ def reflect(self): """Convert the reified ast contained in the internal factbase back into a non-ground program.""" # reset list of program statements before population via reflect - self._program_statements = list() + self._program_ast = [] # should probably define an order in which programs are queried for prog in self._reified.query(preds.Program).all(): - subprogram = self._reflect_predicate(prog) - self._program_statements.extend(subprogram) + subprogram = self.reflect_predicate(prog) + self._program_ast.extend(subprogram) logger.info(f"Reflected program string:\n{self.program_string}") return @@ -539,11 +581,11 @@ def transform( """Transform the reified AST using meta encoding. Parameter meta_prog may be a string path to file containing - meta-encoding, or just the meta-encoding program strin itself. + meta-encoding, or the meta-encoding in string form. """ if len(self._reified) == 0: - logger.warn("Reified AST to be transformed is empty.") + logger.warning("Reified AST to be transformed is empty.") if meta_str is None and meta_files is None: raise ValueError("No meta-program provided for transformation.") meta_prog = "" @@ -554,7 +596,7 @@ def transform( with meta_file.open() as f: meta_prog += f.read() - ctl = Control() + ctl = Control(["--warn=none"]) control_add_facts(ctl, self._reified) ctl.add(meta_prog) ctl.load("./src/renopro/asp/encodings/transform.lp") @@ -563,11 +605,9 @@ def transform( model = next(iter(handle)) ast_symbols = [final.arguments[0] for final in model.symbols(shown=True)] unifier = Unifier(preds.AST_Facts) - try: + with TryUnify(): ast_facts = unifier.iter_unify(ast_symbols, raise_nomatch=True) self._reified = FactBase(ast_facts) - except UnifierNoMatchError as e: - raise_unmatched_ast_fact(e) if __name__ == "__main__": # nocoverage diff --git a/tests/test_reify_reflect.py b/tests/test_reify_reflect.py index e3f07d4..47abdd4 100644 --- a/tests/test_reify_reflect.py +++ b/tests/test_reify_reflect.py @@ -1,45 +1,111 @@ """Test cases for reification functionality.""" from itertools import count from pathlib import Path +from typing import List from unittest import TestCase +import inspect +import re -from clorm import FactBase, parse_fact_files +from clorm import FactBase, Predicate, UnifierNoMatchError +from clingo.ast import parse_string import renopro.predicates as preds -from renopro.reify import ReifiedAST +from renopro.reify import ReifiedAST, ChildQueryError, ChildrenQueryError -test_reify_files = Path("src", "renopro", "asp", "tests", "reify") +reify_files = Path("src", "renopro", "asp", "tests", "reify") +good_reify_files = reify_files / "good_ast" +malformed_reify_files = reify_files / "malformed_ast" -class TestReifyReflect(TestCase): +class TestReifiedAST(TestCase): + def setUp(self): + # reset id generator between test cases so reification + # auto-generates the expected integers + preds.id_count = count() + + +class TestReifiedASTInterface(TestReifiedAST): + """Test interfaces of ReifiedAST class""" + + class NotAnASTFact(Predicate): + pass + + def test_update_ast_facts(self): + """Test updating of reified facts of a ReifiedAST by list of + ast facts.""" + rast = ReifiedAST() + ast_facts = [preds.Variable(id=0, name="X")] + rast.add_reified_facts(ast_facts) + self.assertSetEqual(rast.reified_facts, FactBase(ast_facts)) + rast = ReifiedAST() + with self.assertRaises(UnifierNoMatchError): + rast.add_reified_facts([self.NotAnASTFact()]) + + def test_add_reified_string(self): + """Test adding string representation of reified facts to + reified facts of a ReifiedAST.""" + rast = ReifiedAST() + fact = 'variable(0,"X").\n' + rast.add_reified_string(fact) + fb = FactBase([preds.Variable(id=0, name="X")]) + self.assertSetEqual(rast.reified_facts, fb) + self.assertEqual(rast.reified_string, fact) + rast = ReifiedAST() + with self.assertRaises(UnifierNoMatchError): + rast.add_reified_string('variance(0,"X").') + + def test_unification_error_message(self): + """Test that ReifiedAST rejects adding of facts that do not + unify against any ast predicate definition, with an + informative error message. + + """ + rast = ReifiedAST() + # first case: signature does not match any ast facts + # should show the closest matching signature in the error message. + fact_str = 'litteral(1,"pos",atom(2)).' + regex = r"(?s).*'litteral\(1,\"pos\",atom\(2\)\)'\..*literal/3." + with self.assertRaisesRegex(UnifierNoMatchError, expected_regex=regex): + rast.add_reified_string(fact_str) + # second case: argument of ast fact fails to unify + # should show the argument that failed to unify with the specified field + fact_str = "literal(1,\"pos\",attom(2))." + regex = r"(?s).*'literal\(1,\"pos\",attom\(2\)\)'.*'attom\(2\).*Atom1Field" + with self.assertRaisesRegex(UnifierNoMatchError, expected_regex=regex): + rast.add_reified_string(fact_str) + + def test_reified_files(self): + """Test adding of reified facts from files to reified facts of a ReifiedAST.""" + rast = ReifiedAST() + rast.add_reified_files([malformed_reify_files / "ast_fact.lp"]) + fb = FactBase([preds.Atom(id=12, symbol=preds.Function1(id=13))]) + self.assertSetEqual(rast.reified_facts, fb) + + +class TestReifyReflect(TestReifiedAST): """Base class for tests for reification and reflection of non-ground programs.""" - default_ast_facts = [] base_str = "" - def get_test_facts(self, fact_file_str: str): - """Parse fact file from test directory.""" - facts = parse_fact_files( - [str(test_reify_files / fact_file_str)], unifier=preds.AST_Facts - ) - return facts - - def assertReifyReflectEqual(self, prog_str: str, ast_facts: FactBase): + def assertReifyReflectEqual(self, prog_str: str, ast_fact_files: List[str]): """Assert that reification of prog_str results in ast_facts, and that reflection of ast_facts result in prog_str.""" - ast_facts.add(self.default_ast_facts) - # reset id generator counter so reification generates expected integers - preds.id_count = count() + + ast_fact_files = [(good_reify_files / f) for f in ast_fact_files] for operation in ["reification", "reflection"]: with self.subTest(operation=operation): if operation == "reification": - rast = ReifiedAST() - rast.reify_string(prog_str) - self.assertSetEqual(rast._reified, ast_facts) + rast1 = ReifiedAST() + rast1.reify_string(prog_str) + reified_facts = rast1.reified_facts + rast2 = ReifiedAST() + rast2.add_reified_files(ast_fact_files) + expected_facts = rast2.reified_facts + self.assertSetEqual(reified_facts, expected_facts) elif operation == "reflection": rast = ReifiedAST() - rast.add_reified_facts(ast_facts) + rast.add_reified_files(ast_fact_files) rast.reflect() expected_string = self.base_str + prog_str self.assertEqual(rast.program_string, expected_string) @@ -55,69 +121,90 @@ def setUp(self): # reset id counter between test cases preds.id_count = count() - def test_add_ast_facts(self): - rast = ReifiedAST() - ast_facts = [preds.Variable(id=0, name="X")] - rast.add_reified_facts(ast_facts) - self.assertEqual(rast._reified, FactBase(ast_facts)) - - def test_reify_program_prop_fact(self): + def test_reify_prop_fact(self): """Test reification of a propositional fact.""" - prog_str = "a." - facts = self.get_test_facts("prop_fact.lp") - self.assertReifyReflectEqual(prog_str, facts) + self.assertReifyReflectEqual("a.", ["prop_fact.lp"]) + rast = ReifiedAST() + rast.reify_string("a.") + statements = [] + parse_string("a.", lambda s: statements.append(s)) + self.assertEqual(rast.program_ast, statements) - def test_reify_program_prop_normal_rule(self): + def test_reify_prop_normal_rule(self): """ Test reification of a normal rule containing only propositional atoms. """ - prog_str = "a :- b; not c." - facts = self.get_test_facts("prop_normal_rule.lp") - self.assertReifyReflectEqual(prog_str, facts) + self.assertReifyReflectEqual("a :- b; not c.", ["prop_normal_rule.lp"]) - def test_reify_program_function(self): + def test_reify_function(self): """ Test reification of a variable-free normal rule with function symbols. """ - prog_str = "rel(2,1) :- rel(1,2)." - facts = self.get_test_facts("function.lp") - self.assertReifyReflectEqual(prog_str, facts) + self.assertReifyReflectEqual("rel(2,1) :- rel(1,2).", ["function.lp"]) - def test_reify_program_nested_function(self): - prog_str = "next(move(a))." - facts = self.get_test_facts("nested_function.lp") - self.assertReifyReflectEqual(prog_str, facts) + def test_reify_nested_function(self): + self.assertReifyReflectEqual("next(move(a)).", ["nested_function.lp"]) - def test_reify_program_variable(self): + def test_reify_variable(self): """ Test reification of normal rule with variables. """ - prog_str = "rel(Y,X) :- rel(X,Y)." - facts = self.get_test_facts("variable.lp") - self.assertReifyReflectEqual(prog_str, facts) + self.assertReifyReflectEqual("rel(Y,X) :- rel(X,Y).", ["variable.lp"]) - def test_reify_program_string(self): + def test_reify_string(self): """ Test reification of normal rule with string. """ - prog_str = 'yummy("carrot").' - facts = self.get_test_facts("string.lp") - self.assertReifyReflectEqual(prog_str, facts) + self.assertReifyReflectEqual('yummy("carrot").', ["string.lp"]) - def test_reify_program_constant_term(self): + def test_reify_constant_term(self): """ Test reification of normal rule with constant term. """ - prog_str = "good(human)." - facts = self.get_test_facts("constant_term.lp") - self.assertReifyReflectEqual(prog_str, facts) + self.assertReifyReflectEqual("good(human).", ["constant_term.lp"]) - def test_reify_program_binary_operator(self): - prog_str = "equal((1+1),2)." - facts = self.get_test_facts("binary_operation.lp") - self.assertReifyReflectEqual(prog_str, facts) + def test_reify_binary_operator(self): + self.assertReifyReflectEqual("equal((1+1),2).", ["binary_operation.lp"]) def test_reify_external_false(self): - prog_str = "#external a(X) : c(X); d(e(X)). [false]" - facts = self.get_test_facts("external.lp") - self.assertReifyReflectEqual(prog_str, facts) + """Test reification of an external statement with default value false.""" + self.assertReifyReflectEqual( + "#external a(X) : c(X); d(e(X)). [false]", ["external.lp"] + ) + + def test_reify_program_params(self): + """Test reification of a program statement with parameters.""" + self.assertReifyReflectEqual("#program acid(k).", + ["program_acid.lp"]) + + def test_reify_node_failure(self): + """Reification for any object not of type clingo.ast.AST or + clingo.symbol.Symbol should raise a TypeError.""" + rast = ReifiedAST() + not_node = {"not": "ast"} + regex = "(?s).*AST or Symbol.*dict" + with self.assertRaisesRegex(TypeError, expected_regex=regex): + rast.reify_node(not_node) + + def test_child_query_error_none_found(self): + """Reflection of a parent fact that expects a singe child fact + but finds none should fail with an informative error message. + + """ + rast = ReifiedAST() + rast.add_reified_files([malformed_reify_files / "missing_child.lp"]) + regex = r"(?s).*atom\(4,function\(5\)\).*function\(5\).*found none.*" + with self.assertRaisesRegex(ChildQueryError, expected_regex=regex): + rast.reflect() + + def test_child_query_error_multiple_found(self): + """Reflection of a parent fact that expects a singe child fact + but finds multiple should fail with an informative error + message. + + """ + rast = ReifiedAST() + rast.add_reified_files([malformed_reify_files / "multiple_child.lp"]) + regex = r"(?s).*atom\(4,function\(5\)\).*function\(5\).*found multiple.*" + with self.assertRaisesRegex(ChildQueryError, expected_regex=regex): + rast.reflect() diff --git a/tests/test_transform.py b/tests/test_transform.py index e38a2e4..ca7c574 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -42,3 +42,13 @@ def test_transform_add_time(self): with (files_dir / (testname + "_output.lp")).open("r") as output: expected_str = self.base_str + output.read().strip() self.assertEqual(transformed_str, expected_str) + + def test_transform_bad_input(self): + """Test transform behavior under bad input.""" + rast = ReifiedAST() + # should log warning if rast has no facts before transformation + with self.assertLogs("renopro.reify", level="WARNING"): + rast.transform(meta_str="") + # should raise error if no meta program is provided + with self.assertRaises(ValueError): + rast.transform()