diff --git a/src/renopro/asp/ast.lp b/src/renopro/asp/ast.lp index ae5c2fc..266516a 100644 --- a/src/renopro/asp/ast.lp +++ b/src/renopro/asp/ast.lp @@ -33,6 +33,7 @@ ast(body_literal(Id,Sign,Atom)) :- body_literal(Id,Sign,Atom). ast(body_literals(Id,Pos,Element)) :- body_literals(Id,Pos,Element). ast(head_agg_elements(Id,Pos,Terms,Condition)) :- head_agg_elements(Id,Pos,Terms,Condition). ast(head_aggregate(Id,LGuard,Elements,RGuard)) :- head_aggregate(Id,LGuard,Elements,RGuard). +ast(conditional_literals(Id,Pos,CondLit)) :- conditional_literals(Id,Pos,CondLit). ast(disjunction(Id,Pos,Element)) :- disjunction(Id,Pos,Element). ast(rule(Id,Head,Body)) :- rule(Id,Head,Body). ast(statements(Id,Pos,Element)) :- statements(Id,Pos,Element). diff --git a/src/renopro/asp/defined.lp b/src/renopro/asp/defined.lp index 8381210..a804a00 100644 --- a/src/renopro/asp/defined.lp +++ b/src/renopro/asp/defined.lp @@ -32,6 +32,7 @@ #defined body_literals/3. #defined head_agg_elements/4. #defined head_aggregate/4. +#defined conditional_literals/3. #defined disjunction/3. #defined rule/3. #defined statements/3. diff --git a/src/renopro/predicates.py b/src/renopro/predicates.py index 64a20a2..73e5b57 100644 --- a/src/renopro/predicates.py +++ b/src/renopro/predicates.py @@ -483,12 +483,10 @@ class Theory_Operators1(ComplexTerm, name="theory_operators"): id = Identifier_Field(default=lambda: next(id_count)) -class Theory_Unparsed_Term(Predicate): - """An unparsed theory term is a tuple, each element of which - consists of a tuple of theory operators and a theory term. This - predicate represents an element of an unparsed theory term. +class Theory_Unparsed_Term_Elements(Predicate): + """Predicate representing an element of an unparsed theory term. - id: The identifier of the unparsed theory term. + id: Identifier of the tuple of elements. position: Integer representing position of the element of the theory tuple, ordered by <. operators: A tuple of theory operators. @@ -501,6 +499,26 @@ class Theory_Unparsed_Term(Predicate): term = TheoryTermField +class Theory_Unparsed_Term_Elements1(ComplexTerm, name="theory_unparsed_term_elements"): + "Term identifying a child tuple element of an unparsed theory term." + id = Identifier_Field(default=lambda: next(id_count)) + + +class Theory_Unparsed_Term(Predicate): + """Predicate representing an unparsed theory term. + An unparsed theory term consists of a tuple, each element of which + consists of a tuple of theory operators and a theory term. This + predicate represents an element of an unparsed theory term. + + id: The identifier of the unparsed theory term. + elements: The tuple of aforementioned elements + forming the unparsed theory term. + """ + + id = Identifier_Field(default=lambda: next(id_count)) + elements = Theory_Unparsed_Term_Elements1.Field + + class Theory_Unparsed_Term1(ComplexTerm, name="theory_unparsed_term"): "Term identifying a child unparsed theory term fact." id = Identifier_Field(default=lambda: next(id_count)) @@ -628,7 +646,7 @@ class Literal1(ComplexTerm, name="literal"): class Literals(Predicate): - """Predicate representing an element of a tuple of (conditional) literals. + """Predicate representing an element of a tuple of literals. id: Identifier of the tuple. position: Integer representing position of the element the tuple, ordered by <. @@ -636,7 +654,7 @@ class Literals(Predicate): """ id = Identifier_Field(default=lambda: next(id_count)) - position = IntegerField # should we keep track of position? + position = IntegerField literal = Literal1.Field @@ -913,20 +931,35 @@ class Head_Aggregate1(Predicate, name="head_aggregate"): id = Identifier_Field(default=lambda: next(id_count)) +class Conditional_Literals(Predicate): + """Predicate representing an element of a tuple of conditional literals. + + id: Identifier of the tuple of conditional literals. + position: Integer representing position of the element the tuple, ordered by <. + conditional_literal: Term identifying the conditional literal element. + """ + + id = Identifier_Field(default=lambda: next(id_count)) + position = IntegerField + conditional_literal = Conditional_Literal1.Field + + +class Conditional_Literals1(ComplexTerm, name="conditional_literals"): + "Term identifying a child tuple of conditional literals." + id = Identifier_Field(default=lambda: next(id_count)) + + class Disjunction(Predicate): """Predicate representing a disjunction of (conditional) literals. id: Identifier of the disjunction. - position: Integer representing position of the element the disjunction, - ordered by <. - element: The element of the disjunction, a conditional literal. + elements: The elements of the disjunction, a tuple of conditional literals. A literal in a disjunction is represented as a conditional literal with an empty condition. """ id = Identifier_Field(default=lambda: next(id_count)) - position = IntegerField - conditional_literal = Conditional_Literal1.Field + elements = Conditional_Literals1.Field class Disjunction1(ComplexTerm, name="disjunction"): @@ -1057,6 +1090,7 @@ class Program(Predicate): Theory_Sequence, Theory_Function, Theory_Operators, + Theory_Unparsed_Term_Elements, Theory_Unparsed_Term, Guard, Guards, @@ -1077,6 +1111,7 @@ class Program(Predicate): Body_Literals, Head_Agg_Elements, Head_Aggregate, + Conditional_Literals, Disjunction, Rule, Statements, @@ -1099,6 +1134,7 @@ class Program(Predicate): Theory_Sequence, Theory_Function, Theory_Operators, + Theory_Unparsed_Term_Elements, Theory_Unparsed_Term, Guard, Guards, @@ -1119,6 +1155,7 @@ class Program(Predicate): Body_Literals, Head_Agg_Elements, Head_Aggregate, + Conditional_Literals, Disjunction, Rule, Statements, diff --git a/src/renopro/rast.py b/src/renopro/rast.py index 4100f0b..7bed939 100644 --- a/src/renopro/rast.py +++ b/src/renopro/rast.py @@ -1,13 +1,12 @@ # pylint: disable=too-many-lines """Module implementing reification and de-reification of non-ground programs""" -import enum import inspect import logging import re from contextlib import AbstractContextManager from functools import singledispatchmethod from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Type, Union +from typing import Iterator, List, Literal, Optional, Sequence, Type, Union, overload from clingo import Control, ast, symbol from clingo.ast import ( @@ -33,6 +32,7 @@ from thefuzz import process # type: ignore import renopro.predicates as preds +from renopro.utils import assert_never from renopro.utils.logger import get_clingo_logger_callback logger = logging.getLogger(__name__) @@ -260,7 +260,7 @@ def reify_node(self, node): raise TypeError(f"Nodes should be of type AST or Symbol, got: {type(node)}") def _reify_ast_seqence( - self, seq: ASTSequence, tup_id: BaseField, tup_pred: Type[Predicate] + self, seq: ASTSequence, tup_id: BaseField, tup_pred: Type[preds.AstPredicate] ): """Reify ast sequence into a tuple of predicates of type tup_pred with identifier tup_id.""" @@ -407,8 +407,9 @@ def _reify_theory_function(self, node): return theory_func1 @reify_node.register(ASTType.TheoryUnparsedTerm) - def _reify_unparsed_theory_term(self, node): + def _reify_theory_unparsed_term(self, node): reified_unparsed_theory_term1 = preds.Theory_Unparsed_Term1() + reified_unparsed_elements1 = preds.Theory_Unparsed_Term_Elements1() # reified_unparsed_theory_term = preds.Theory_Unparsed_Term( # id=reified_unparsed_theory_term1.id, # elements=preds.Theory_Unparsed_Term_Elements1(), @@ -421,13 +422,17 @@ def _reify_unparsed_theory_term(self, node): ] self._reified.add(reified_operators) reified_theory_term1 = self.reify_node(element.term) - reified_unparsed_theory_term = preds.Theory_Unparsed_Term( - id=reified_unparsed_theory_term1.id, + reified_unparsed_elements = preds.Theory_Unparsed_Term_Elements( + id=reified_unparsed_elements1.id, position=pos, operators=operators, term=reified_theory_term1, ) - self._reified.add(reified_unparsed_theory_term) + self._reified.add(reified_unparsed_elements) + reified_unparsed = preds.Theory_Unparsed_Term( + id=reified_unparsed_theory_term1.id, elements=reified_unparsed_elements1 + ) + self._reified.add(reified_unparsed) return reified_unparsed_theory_term1 @reify_node.register(ASTType.Guard) @@ -503,9 +508,11 @@ def _reify_aggregate(self, node) -> preds.Aggregate1: ) elements1 = preds.Agg_Elements1() self._reify_ast_seqence(node.elements, elements1.id, preds.Agg_Elements) - right_guard = preds.Guard1() - if node.right_guard is not None: - self.reify_node(node.right_guard) + right_guard = ( + preds.Guard1() + if node.right_guard is None + else self.reify_node(node.right_guard) + ) count_agg = preds.Aggregate( id=count_agg1.id, left_guard=left_guard, @@ -593,6 +600,30 @@ def _reify_body_aggregate(self, node) -> preds.Body_Aggregate1: self._reified.add(agg) return agg1 + def _reify_body_literals(self, body_lits: Sequence[ast.AST], body_id): + reified_body_lits = [] + for pos, lit in enumerate(body_lits, start=0): + if lit.ast_type is ast.ASTType.ConditionalLiteral: + cond_lit1 = self.reify_node(lit) + reified_body_lits.append( + preds.Body_Literals( + id=body_id, position=pos, body_literal=cond_lit1 + ) + ) + else: + body_lit1 = preds.Body_Literal1() + reified_body_lits.append( + preds.Body_Literals( + id=body_id, position=pos, body_literal=body_lit1 + ) + ) + clorm_sign = preds.convert_enum(ast.Sign(lit.sign), preds.Sign) + body_lit = preds.Body_Literal( + id=body_lit1.id, sig=clorm_sign, atom=self.reify_node(lit.atom) + ) + self._reified.add(body_lit) + self._reified.add(reified_body_lits) + @reify_node.register(ASTType.HeadAggregate) def _reify_head_aggregate(self, node) -> preds.Head_Aggregate1: agg1 = preds.Head_Aggregate1() @@ -630,37 +661,25 @@ def _reify_head_aggregate(self, node) -> preds.Head_Aggregate1: self._reified.add(agg) return agg1 - def _reify_body_literals(self, body_lits: Sequence[ast.AST], body_id): - reified_body_lits = [] - for pos, lit in enumerate(body_lits, start=0): - if lit.ast_type is ast.ASTType.ConditionalLiteral: - cond_lit1 = self.reify_node(lit) - reified_body_lits.append( - preds.Body_Literals(id=body_id, position=pos, body_literal=cond_lit1) - ) - else: - body_lit1 = preds.Body_Literal1() - reified_body_lits.append( - preds.Body_Literals(id=body_id, position=pos, body_literal=body_lit1) - ) - clorm_sign = preds.convert_enum(ast.Sign(lit.sign), preds.Sign) - body_lit = preds.Body_Literal( - id=body_lit1.id, sig=clorm_sign, atom=self.reify_node(lit.atom) - ) - self._reified.add(body_lit) - self._reified.add(reified_body_lits) + @reify_node.register(ASTType.Disjunction) + def _reify_disjunction(self, node) -> preds.Disjunction1: + disj1 = preds.Disjunction1() + cond_lits1 = preds.Conditional_Literals1() + self._reify_ast_seqence( + node.elements, cond_lits1.id, preds.Conditional_Literals + ) + disj = preds.Disjunction(id=disj1.id, elements=cond_lits1) + self._reified.add(disj) + return disj1 @reify_node.register(ASTType.Rule) def _reify_rule(self, node): rule1 = preds.Rule1() - if node.head.ast_type is ASTType.Disjunction: - head = preds.Disjunction1() - self._reify_ast_seqence(node.head.elements, head.id, preds.Disjunction) - else: - head = self.reify_node(node.head) + head = self.reify_node(node.head) rule = preds.Rule(id=rule1.id, head=head, body=preds.Body_Literals1()) self._reified.add(rule) - statement_tup = preds.Statements(id=self._statement_tup_id, position=self._statement_pos, statement=rule1 + statement_tup = preds.Statements( + id=self._statement_tup_id, position=self._statement_pos, statement=rule1 ) self._reified.add(statement_tup) self._statement_pos += 1 @@ -674,7 +693,7 @@ def _reify_program(self, node): id=const_tup_id.id, position=pos, constant=preds.Function1() ) const = preds.Function( - id=const_tup.element.id, name=param.name, arguments=preds.Terms1() + id=const_tup.constant.id, name=param.name, arguments=preds.Terms1() ) self._reified.add([const_tup, const]) program = preds.Program( @@ -702,18 +721,46 @@ def _reify_external(self, node): self._statement_pos += 1 self._reify_body_literals(node.body, external.body.id) - class ExpectedNumberOfChilden(str, enum.Enum): - "String symbols representing expected number of child facts." - One = "1" - ZeroOrOne = "?" - ZeroOrMore = "*" - OneOrMore = "+" + ExpectedNum = Literal["1", "?", "+", "*"] - def reflect_child( + @overload + def _reflect_child( self, parent_fact: preds.AstPredicate, child_id_fact, - expected_children_num: ExpectedNumberOfChilden = "1", + expected_children_num: Literal["1"], + ) -> AST: # nocoverage + ... + + # for handling the default argument "1" + + @overload + def _reflect_child(self, parent_fact: preds.AstPredicate, child_id_fact) -> AST: #nocoverage + ... + + @overload + def _reflect_child( + self, + parent_fact: preds.AstPredicate, + child_id_fact, + expected_children_num: Literal["?"], + ) -> Optional[AST]: # nocoverage + ... + + @overload + def _reflect_child( + self, + parent_fact: preds.AstPredicate, + child_id_fact, + expected_children_num: Literal["*", "+"], + ) -> Sequence[AST]: # nocoverage + ... + + def _reflect_child( + self, + parent_fact: preds.AstPredicate, + child_id_fact, + expected_children_num: ExpectedNum = "1", ) -> Union[None, AST, Sequence[AST]]: """Utility function that takes a unary ast predicate identifying a child predicate, queries reified factbase for @@ -751,7 +798,7 @@ def reflect_child( f"'{child_id_fact}', found {num_child_facts}." ) raise ChildQueryError(base_msg + msg) - elif expected_children_num in ["*", "+"]: + elif expected_children_num == "*" or expected_children_num == "+": query = query.order_by(child_ast_pred.position) child_facts = list(query.all()) num_child_facts = len(child_facts) @@ -767,14 +814,10 @@ def reflect_child( f"Expected 1 or more child facts for identifier " f"'{child_id_fact}', found 0." ) - raise ChildQueryError(base_msg + msg) + raise ChildrenQueryError(base_msg + msg) child_nodes = [self.reflect_predicate(fact) for fact in child_facts] return child_nodes - else: - raise RuntimeError( - "Invalid expected_children_num argument. Value must be one of " - f"{[e.value for e in self.ExpectedNumberOfChilden]}." - ) # nocoverage + assert_never(expected_children_num) @singledispatchmethod def reflect_predicate(self, pred: preds.AstPredicate): # nocoverage @@ -816,7 +859,7 @@ def _reflect_unary_operation(self, operation: preds.Unary_Operation) -> AST: return ast.UnaryOperation( location=DUMMY_LOC, operator_type=clingo_operator, - argument=self.reflect_child(operation, operation.argument), + argument=self._reflect_child(operation, operation.argument), ) @reflect_predicate.register @@ -825,8 +868,8 @@ def _reflect_binary_operation(self, operation: preds.Binary_Operation) -> AST: clingo_operator = preds.convert_enum( preds.BinaryOperator(operation.operator), ast.BinaryOperator ) - reflected_left = self.reflect_child(operation, operation.left) - reflected_right = self.reflect_child(operation, operation.right) + reflected_left = self._reflect_child(operation, operation.left) + reflected_right = self._reflect_child(operation, operation.right) return ast.BinaryOperation( location=DUMMY_LOC, operator_type=clingo_operator, @@ -836,15 +879,15 @@ def _reflect_binary_operation(self, operation: preds.Binary_Operation) -> AST: @reflect_predicate.register def _reflect_interval(self, interval: preds.Interval) -> AST: - reflected_left = self.reflect_child(interval, interval.left) - reflected_right = self.reflect_child(interval, interval.right) + reflected_left = self._reflect_child(interval, interval.left) + reflected_right = self._reflect_child(interval, interval.right) return ast.Interval( location=DUMMY_LOC, left=reflected_left, right=reflected_right ) @reflect_predicate.register def _reflect_terms(self, terms: preds.Terms) -> AST: - return self.reflect_child(terms, terms.term) + return self._reflect_child(terms, terms.term) @reflect_predicate.register def _reflect_function(self, func: preds.Function) -> AST: @@ -858,26 +901,26 @@ def _reflect_function(self, func: preds.Function) -> AST: to handle them differently when reifying. """ - arg_nodes = self.reflect_child(func, func.arguments, "*") + arg_nodes = self._reflect_child(func, func.arguments, "*") return ast.Function( location=DUMMY_LOC, name=str(func.name), arguments=arg_nodes, external=0 ) @reflect_predicate.register def _reflect_pool(self, pool: preds.Pool) -> AST: - arg_nodes = self.reflect_child(pool, pool.arguments, "*") + arg_nodes = self._reflect_child(pool, pool.arguments, "*") return ast.Pool(location=DUMMY_LOC, arguments=arg_nodes) @reflect_predicate.register def _reflect_theory_terms(self, theory_terms: preds.Theory_Terms) -> AST: - return self.reflect_child(theory_terms, theory_terms.theory_term) + return self._reflect_child(theory_terms, theory_terms.theory_term) @reflect_predicate.register def _reflect_theory_sequence(self, theory_seq: preds.Theory_Sequence) -> AST: clingo_theory_sequence_type = preds.convert_enum( preds.TheorySequenceType(theory_seq.sequence_type), ast.TheorySequenceType ) - theory_term_nodes = self.reflect_child(theory_seq, theory_seq.terms, "*") + theory_term_nodes = self._reflect_child(theory_seq, theory_seq.terms, "*") return ast.TheorySequence( location=DUMMY_LOC, sequence_type=clingo_theory_sequence_type, @@ -886,7 +929,7 @@ def _reflect_theory_sequence(self, theory_seq: preds.Theory_Sequence) -> AST: @reflect_predicate.register def _reflect_theory_function(self, theory_func: preds.Theory_Function) -> AST: - arguments = self.reflect_child(theory_func, theory_func.arguments, "*") + arguments = self._reflect_child(theory_func, theory_func.arguments, "*") return ast.TheoryFunction( location=DUMMY_LOC, name=str(theory_func.name), arguments=arguments ) @@ -895,50 +938,43 @@ def _reflect_theory_function(self, theory_func: preds.Theory_Function) -> AST: def _reflect_theory_operators( self, theory_operators: preds.Theory_Operators ) -> AST: - return self.reflect_child(theory_operators, theory_operators.operator) + return theory_operators.operator + + @reflect_predicate.register + def _reflect_theory_unparsed_term_elements( + self, elements: preds.Theory_Unparsed_Term_Elements + ) -> AST: + reflected_operators = self._reflect_child(elements, elements.operators, "*") + reflected_term = self._reflect_child(elements, elements.term) + return ast.TheoryUnparsedTermElement( + operators=reflected_operators, term=reflected_term + ) @reflect_predicate.register def _reflect_theory_unparsed_term( self, theory_unparsed_term: preds.Theory_Unparsed_Term ) -> AST: - # child_elements = self._get_child_facts( - # theory_unparsed_term, theory_unparsed_term.elements - # ) - # if len(child_elements) == 0: - # msg = ( - # f"Error finding child facts of predicate '{theory_unparsed_term}'.\n" - # "Found no child 'theory_unparsed_term_elements' facts with identifier " - # f"matching {theory_unparsed_term.elements}, expected at least one." - # ) - # raise ChildQueryError(msg) - # reflected_elements = [] - # for element in child_elements: - # child_operators = self._get_child_facts(element, element.ope) - # reflected_operators = [operator.operator for operator in child_operators] - # reflected_term = self._reflect_child_pred(element, element.term) - # reflected_element = ast.TheoryUnparsedTermElement( - # operators=reflected_operators, term=reflected_term - # ) - # reflected_elements.append(reflected_element) - # return ast.TheoryUnparsedTerm(location=DUMMY_LOC, elements=reflected_elements) - pass + reflected_elements = self._reflect_child( + theory_unparsed_term, theory_unparsed_term.elements, "*" + ) + return ast.TheoryUnparsedTerm(location=DUMMY_LOC, elements=reflected_elements) @reflect_predicate.register def _reflect_guard(self, guard: preds.Guard) -> AST: clingo_operator = preds.convert_enum( preds.ComparisonOperator(guard.comparison), ast.ComparisonOperator ) - reflected_guard = self.reflect_child(guard, guard.term) + reflected_guard = self._reflect_child(guard, guard.term) return ast.Guard(comparison=clingo_operator, term=reflected_guard) @reflect_predicate.register def _reflect_guards(self, guards: preds.Guards) -> AST: - return self.reflect_child(guards, guards.guard) + return self._reflect_child(guards, guards.guard) @reflect_predicate.register def _reflect_comparison(self, comparison: preds.Comparison) -> AST: - term_node = self.reflect_child(comparison, comparison.term) - guard_nodes = self.reflect_child(comparison, comparison.guards, "+") + term_node = self._reflect_child(comparison, comparison.term) + guard_nodes = self._reflect_child(comparison, comparison.guards, "+") return ast.Comparison(term=term_node, guards=guard_nodes) @reflect_predicate.register @@ -954,42 +990,38 @@ def _reflect_boolean_constant(self, bool_const: preds.Boolean_Constant) -> AST: @reflect_predicate.register def _reflect_symbolic_atom(self, atom: preds.Symbolic_Atom) -> AST: - reflected_symbol = self.reflect_child(atom, atom.symbol) + reflected_symbol = self._reflect_child(atom, atom.symbol) return ast.SymbolicAtom(symbol=reflected_symbol) @reflect_predicate.register def _reflect_literal(self, lit: preds.Literal) -> AST: clingo_sign = preds.convert_enum(preds.Sign(lit.sig), ast.Sign) - reflected_atom = self.reflect_child(lit, lit.atom) + reflected_atom = self._reflect_child(lit, lit.atom) return ast.Literal(location=DUMMY_LOC, sign=clingo_sign, atom=reflected_atom) @reflect_predicate.register def _reflect_literals(self, literals: preds.Literals) -> AST: - self.reflect_child(literals, literals.literal) + return self._reflect_child(literals, literals.literal) @reflect_predicate.register def _reflect_conditional_literal(self, cond_lit: preds.Conditional_Literal) -> AST: - reflected_literal = self.reflect_child(cond_lit, cond_lit.literal) - reflected_condition = self.reflect_child( - cond_lit, cond_lit.condition, "*" - ) + reflected_literal = self._reflect_child(cond_lit, cond_lit.literal) + reflected_condition = self._reflect_child(cond_lit, cond_lit.condition, "*") return ast.ConditionalLiteral( location=DUMMY_LOC, literal=reflected_literal, condition=reflected_condition ) @reflect_predicate.register def _reflect_agg_elements(self, agg_elements: preds.Agg_Elements) -> AST: - return self.reflect_child(agg_elements, agg_elements.element) + return self._reflect_child(agg_elements, agg_elements.element) @reflect_predicate.register def _reflect_aggregate(self, aggregate: preds.Aggregate) -> AST: - reflected_left_guard = self.reflect_child( - aggregate, aggregate.left_guard, "?" - ) - reflected_right_guard = ( - self.reflect_child(aggregate, aggregate.right_guard, "?"), + reflected_left_guard = self._reflect_child(aggregate, aggregate.left_guard, "?") + reflected_right_guard = self._reflect_child( + aggregate, aggregate.right_guard, "?" ) - reflected_elements = self.reflect_child(aggregate, aggregate.elements, "*") + reflected_elements = self._reflect_child(aggregate, aggregate.elements, "*") return ast.Aggregate( location=DUMMY_LOC, left_guard=reflected_left_guard, @@ -1000,9 +1032,7 @@ def _reflect_aggregate(self, aggregate: preds.Aggregate) -> AST: @reflect_predicate.register def _reflect_theory_guard(self, theory_guard: preds.Theory_Guard) -> AST: reflected_operator_name = theory_guard.operator_name - reflected_theory_term = self.reflect_child( - theory_guard, theory_guard.term - ) + reflected_theory_term = self._reflect_child(theory_guard, theory_guard.term) return ast.TheoryGuard( operator_name=str(reflected_operator_name), term=reflected_theory_term ) @@ -1011,21 +1041,17 @@ def _reflect_theory_guard(self, theory_guard: preds.Theory_Guard) -> AST: def _reflect_theory_atom_elements( self, elements: preds.Theory_Atom_Elements ) -> AST: - reflected_terms = self.reflect_child(elements, elements.terms, "*") - reflected_condition = self.reflect_child( - elements, elements.condition, "*" - ) + reflected_terms = self._reflect_child(elements, elements.terms, "*") + reflected_condition = self._reflect_child(elements, elements.condition, "*") return ast.TheoryAtomElement( terms=reflected_terms, condition=reflected_condition ) @reflect_predicate.register def _reflect_theory_atom(self, theory_atom: preds.Theory_Atom) -> AST: - reflected_syb_atom = self.reflect_child(theory_atom, theory_atom.atom) - reflected_elements = self.reflect_child( - theory_atom, theory_atom.elements, "*" - ) - reflected_guard = self.reflect_child(theory_atom, theory_atom.guard, "?") + reflected_syb_atom = self._reflect_child(theory_atom, theory_atom.atom) + reflected_elements = self._reflect_child(theory_atom, theory_atom.elements, "*") + reflected_guard = self._reflect_child(theory_atom, theory_atom.guard, "?") return ast.TheoryAtom( location=DUMMY_LOC, term=reflected_syb_atom.symbol, @@ -1035,26 +1061,20 @@ def _reflect_theory_atom(self, theory_atom: preds.Theory_Atom) -> AST: @reflect_predicate.register def _reflect_body_agg_elements(self, elements: preds.Body_Agg_Elements) -> AST: - reflected_terms = self.reflect_child(elements, elements.terms, "*") - reflected_condition = self.reflect_child( - elements, elements.condition, "*" - ) + reflected_terms = self._reflect_child(elements, elements.terms, "*") + reflected_condition = self._reflect_child(elements, elements.condition, "*") return ast.BodyAggregateElement( terms=reflected_terms, condition=reflected_condition ) @reflect_predicate.register def _reflect_body_aggregate(self, aggregate: preds.Body_Aggregate) -> AST: - reflected_left_guard = self.reflect_child( - aggregate, aggregate.left_guard, "?" - ) + reflected_left_guard = self._reflect_child(aggregate, aggregate.left_guard, "?") reflected_agg_function = preds.convert_enum( preds.AggregateFunction(aggregate.function), ast.AggregateFunction ) - reflected_elements = self.reflect_child( - aggregate, aggregate.elements, "*" - ) - reflected_right_guard = self.reflect_child( + reflected_elements = self._reflect_child(aggregate, aggregate.elements, "*") + reflected_right_guard = self._reflect_child( aggregate, aggregate.right_guard, "?" ) return ast.BodyAggregate( @@ -1067,34 +1087,28 @@ def _reflect_body_aggregate(self, aggregate: preds.Body_Aggregate) -> AST: @reflect_predicate.register def _reflect_body_literals(self, body_literals: preds.Body_Literals) -> AST: - return self.reflect_child(body_literals, body_literals.body_literal) - + return self._reflect_child(body_literals, body_literals.body_literal) + @reflect_predicate.register def _reflect_body_literal(self, body_lit: preds.Body_Literal) -> AST: return self._reflect_literal(body_lit) @reflect_predicate.register def _reflect_head_agg_elements(self, elements: preds.Head_Agg_Elements) -> AST: - reflected_terms = self.reflect_child(elements, elements.terms, "*") - reflected_condition = self.reflect_child( - elements, elements.condition - ) + reflected_terms = self._reflect_child(elements, elements.terms, "*") + reflected_condition = self._reflect_child(elements, elements.condition) return ast.HeadAggregateElement( terms=reflected_terms, condition=reflected_condition ) @reflect_predicate.register def _reflect_head_aggregate(self, aggregate: preds.Head_Aggregate) -> AST: - reflected_left_guard = self.reflect_child( - aggregate, aggregate.left_guard, "?" - ) + reflected_left_guard = self._reflect_child(aggregate, aggregate.left_guard, "?") reflected_agg_function = preds.convert_enum( preds.AggregateFunction(aggregate.function), ast.AggregateFunction ) - reflected_elements = self.reflect_child( - aggregate, aggregate.elements, "*" - ) - reflected_right_guard = self.reflect_child( + reflected_elements = self._reflect_child(aggregate, aggregate.elements, "*") + reflected_right_guard = self._reflect_child( aggregate, aggregate.right_guard, "?" ) return ast.HeadAggregate( @@ -1105,30 +1119,31 @@ def _reflect_head_aggregate(self, aggregate: preds.Head_Aggregate) -> AST: right_guard=reflected_right_guard, ) + @reflect_predicate.register + def _reflect_conditional_literals( + self, cond_lits: preds.Conditional_Literals + ) -> AST: + return self._reflect_child(cond_lits, cond_lits.conditional_literal) + @reflect_predicate.register def _reflect_disjunction(self, disjunction: preds.Disjunction) -> AST: - return self.reflect_child(disjunction, disjunction.conditional_literal) + reflected_elements = self._reflect_child(disjunction, disjunction.elements, "*") + return ast.Disjunction(location=DUMMY_LOC, elements=reflected_elements) @reflect_predicate.register def _reflect_rule(self, rule: preds.Rule) -> AST: """Reflect a Rule fact into a Rule node.""" - if isinstance(rule.head, preds.Disjunction1): - head_node = ast.Disjunction( - location=DUMMY_LOC, - elements=self.reflect_child(rule, rule.head, "*"), - ) - else: - head_node = self.reflect_child(rule, rule.head) - body_nodes = self.reflect_child(rule, rule.body, "*") - return ast.Rule(location=DUMMY_LOC, head=head_node, body=body_nodes) + reflected_head = self._reflect_child(rule, rule.head) + reflected_body = self._reflect_child(rule, rule.body, "*") + return ast.Rule(location=DUMMY_LOC, head=reflected_head, body=reflected_body) @reflect_predicate.register def _reflect_statements(self, statements: preds.Statements) -> AST: - return self.reflect_child(statements, statements.statement) + return self._reflect_child(statements, statements.statement) @reflect_predicate.register def _reflect_constants(self, constants: preds.Constants) -> AST: - return self.reflect_child(constants, constants.constant) + return self._reflect_child(constants, constants.constant) @reflect_predicate.register def _reflect_program(self, program: preds.Program) -> Sequence[AST]: @@ -1137,21 +1152,21 @@ def _reflect_program(self, program: preds.Program) -> Sequence[AST]: """ subprogram = [] - parameter_nodes = self.reflect_child(program, program.parameters, "*") + parameter_nodes = self._reflect_child(program, program.parameters, "*") subprogram.append( ast.Program( location=DUMMY_LOC, name=str(program.name), parameters=parameter_nodes ) ) - statement_nodes = self.reflect_child(program, program.statements, "*") + statement_nodes = self._reflect_child(program, program.statements, "*") subprogram.extend(statement_nodes) return subprogram @reflect_predicate.register def _reflect_external(self, external: preds.External) -> AST: """Reflect an External fact into an External node.""" - symb_atom_node = self.reflect_child(external, external.atom) - body_nodes = self.reflect_child(external, external.body, "*") + symb_atom_node = self._reflect_child(external, external.atom) + body_nodes = self._reflect_child(external, external.body, "*") ext_type = ast.SymbolicTerm( location=DUMMY_LOC, symbol=symbol.Function(name=str(external.external_type), arguments=[]), diff --git a/src/renopro/utils/__init__.py b/src/renopro/utils/__init__.py index efbf5c4..b0d5729 100644 --- a/src/renopro/utils/__init__.py +++ b/src/renopro/utils/__init__.py @@ -1,3 +1,10 @@ """ Utilities. """ +from typing import NoReturn + + +def assert_never(value: NoReturn) -> NoReturn: + """Function to help mypy make exhaustiveness check when + e.g. dispatching on enum values.""" + assert False, f"This code should never be reached, got: {value}" diff --git a/tests/asp/reify_reflect/malformed_ast/multiple_child.lp b/tests/asp/reify_reflect/malformed_ast/one_expected_multiple_found.lp similarity index 100% rename from tests/asp/reify_reflect/malformed_ast/multiple_child.lp rename to tests/asp/reify_reflect/malformed_ast/one_expected_multiple_found.lp diff --git a/tests/asp/reify_reflect/malformed_ast/missing_child.lp b/tests/asp/reify_reflect/malformed_ast/one_expected_zero_found.lp similarity index 100% rename from tests/asp/reify_reflect/malformed_ast/missing_child.lp rename to tests/asp/reify_reflect/malformed_ast/one_expected_zero_found.lp diff --git a/tests/asp/reify_reflect/malformed_ast/missing_guard_in_comparison.lp b/tests/asp/reify_reflect/malformed_ast/one_or_more_expected_found_zero.lp similarity index 100% rename from tests/asp/reify_reflect/malformed_ast/missing_guard_in_comparison.lp rename to tests/asp/reify_reflect/malformed_ast/one_or_more_expected_found_zero.lp diff --git a/tests/asp/reify_reflect/malformed_ast/zero_or_more_expected_multiple_found.lp b/tests/asp/reify_reflect/malformed_ast/zero_or_more_expected_multiple_found.lp new file mode 100644 index 0000000..d62203e --- /dev/null +++ b/tests/asp/reify_reflect/malformed_ast/zero_or_more_expected_multiple_found.lp @@ -0,0 +1,10 @@ +% malformed ast with multiple aggregate guards. + +program("base",constants(0),statements(1)). +statements(1,0,rule(2)). +rule(2,aggregate(3),body_literals(9)). +aggregate(3,guard(4),agg_elements(7),guard(8)). +guard(4,"<=",number(5)). +number(5,1). +guard(4,"<",number(6)). +number(6,3). diff --git a/tests/asp/reify_reflect/well_formed_ast/aggregate2.lp b/tests/asp/reify_reflect/well_formed_ast/aggregate2.lp new file mode 100644 index 0000000..9938c82 --- /dev/null +++ b/tests/asp/reify_reflect/well_formed_ast/aggregate2.lp @@ -0,0 +1,17 @@ +% reified fact representation of program: +% #program base. +% 1 <= { a } < 3. + +program("base",constants(0),statements(1)). +statements(1,0,rule(2)). +rule(2,aggregate(3),body_literals(15)). +aggregate(3,guard(4),agg_elements(6),guard(13)). +guard(4,"<=",number(5)). +number(5,1). +agg_elements(6,0,conditional_literal(7)). +conditional_literal(7,literal(8),literals(12)). +literal(8,"pos",symbolic_atom(9)). +symbolic_atom(9,function(10)). +function(10,a,terms(11)). +guard(13,"<",number(14)). +number(14,3). diff --git a/tests/asp/reify_reflect/well_formed_ast/disjunction.lp b/tests/asp/reify_reflect/well_formed_ast/disjunction.lp index 907ef49..b8da52c 100644 --- a/tests/asp/reify_reflect/well_formed_ast/disjunction.lp +++ b/tests/asp/reify_reflect/well_formed_ast/disjunction.lp @@ -4,22 +4,23 @@ program("base",constants(0),statements(1)). statements(1,0,rule(2)). -rule(2,disjunction(3),body_literals(24)). -disjunction(3,0,conditional_literal(4)). -conditional_literal(4,literal(5),literals(9)). -literal(5,"pos",symbolic_atom(6)). -symbolic_atom(6,function(7)). -function(7,a,terms(8)). -disjunction(3,1,conditional_literal(10)). -conditional_literal(10,literal(11),literals(15)). -literal(11,"pos",symbolic_atom(12)). -symbolic_atom(12,function(13)). -function(13,b,terms(14)). -literals(15,0,literal(16)). -literal(16,"pos",symbolic_atom(17)). -symbolic_atom(17,function(18)). -function(18,c,terms(19)). -literals(15,1,literal(20)). -literal(20,"not",symbolic_atom(21)). -symbolic_atom(21,function(22)). -function(22,d,terms(23)). +rule(2,disjunction(3),body_literals(25)). +disjunction(3,conditional_literals(4)). +conditional_literals(4,0,conditional_literal(5)). +conditional_literal(5,literal(6),literals(10)). +literal(6,"pos",symbolic_atom(7)). +symbolic_atom(7,function(8)). +function(8,a,terms(9)). +conditional_literals(4,1,conditional_literal(11)). +conditional_literal(11,literal(12),literals(16)). +literal(12,"pos",symbolic_atom(13)). +symbolic_atom(13,function(14)). +function(14,b,terms(15)). +literals(16,0,literal(17)). +literal(17,"pos",symbolic_atom(18)). +symbolic_atom(18,function(19)). +function(19,c,terms(20)). +literals(16,1,literal(21)). +literal(21,"not",symbolic_atom(22)). +symbolic_atom(22,function(23)). +function(23,d,terms(24)). diff --git a/tests/asp/reify_reflect/well_formed_ast/theory_unparsed_term.lp b/tests/asp/reify_reflect/well_formed_ast/theory_unparsed_term.lp index 7f18322..01d2c76 100644 --- a/tests/asp/reify_reflect/well_formed_ast/theory_unparsed_term.lp +++ b/tests/asp/reify_reflect/well_formed_ast/theory_unparsed_term.lp @@ -4,17 +4,17 @@ program("base",constants(0),statements(1)). statements(1,0,rule(2)). -rule(2,theory_atom(3),body_literals(15)). -theory_atom(3,symbolic_atom(4),theory_atom_elements(7),theory_guard(14)). +rule(2,theory_atom(3),body_literals(17)). +theory_atom(3,symbolic_atom(4),theory_atom_elements(7),theory_guard(16)). symbolic_atom(4,function(5)). function(5,a,terms(6)). -theory_atom_elements(7,0,theory_terms(8),literals(13)). +theory_atom_elements(7,0,theory_terms(8),literals(15)). theory_terms(8,0,theory_unparsed_term(9)). theory_unparsed_term(9,theory_unparsed_term_elements(10)). theory_unparsed_term_elements(10,0,theory_operators(11),number(12)). theory_operators(11,0,"+"). number(12,1). -theory_unparsed_term_elements(10,1,theory_operators(12),string(13)). -theory_operators(12,0,"!"). -theory_operators(12,1,"-"). -string(13,"b"). +theory_unparsed_term_elements(10,1,theory_operators(13),string(14)). +theory_operators(13,0,"!-"). +theory_operators(13,1,">"). +string(14,"b"). diff --git a/tests/test_reify_reflect.py b/tests/test_reify_reflect.py index 709caa9..cb2a7e1 100644 --- a/tests/test_reify_reflect.py +++ b/tests/test_reify_reflect.py @@ -213,25 +213,25 @@ def test_rast_node_failure(self): with self.assertRaisesRegex(TypeError, expected_regex=regex): rast.reify_node(not_node) - def test_child_query_error_none_found(self): + def test_child_query_error_one_expected_zero_found(self): """Reflection of a parent fact that expects a singe child fact but finds none should fail with an informative error message. """ rast = ReifiedAST() - rast.add_reified_files([malformed_ast_files / "missing_child.lp"]) + rast.add_reified_files([malformed_ast_files / "one_expected_zero_found.lp"]) regex = r"(?s).*atom\(4,function\(5\)\).*function\(5\).*found 0.*" with self.assertRaisesRegex(ChildQueryError, expected_regex=regex): rast.reflect() - def test_child_query_error_multiple_found(self): + def test_child_query_error_one_expected_multiple_found(self): """Reflection of a parent fact that expects a singe child fact but finds multiple should fail with an informative error message. """ rast = ReifiedAST() - rast.add_reified_files([malformed_ast_files / "multiple_child.lp"]) + rast.add_reified_files([malformed_ast_files / "one_expected_multiple_found.lp"]) regex = r"(?s).*atom\(4,function\(5\)\).*function\(5\).*found 2.*" with self.assertRaisesRegex(ChildQueryError, expected_regex=regex): rast.reflect() @@ -248,17 +248,30 @@ def test_children_query_error_multiple_in_same_position(self): with self.assertRaisesRegex(ChildrenQueryError, expected_regex=regex): rast.reflect() - def test_comparison_no_guard_found(self): - """Reflection of a comparison fact with an empty guard tuple should fail.""" + def test_children_query_error_one_or_more_expected_zero_found(self): + """Reflection of a child facts where one or more facts are expected + should fail with an informative error message when zero are found.""" rast = ReifiedAST() - rast.add_reified_files([malformed_ast_files / "missing_guard_in_comparison.lp"]) + rast.add_reified_files([malformed_ast_files / "one_or_more_expected_found_zero.lp"]) regex = ( r"(?s).*comparison\(4,number\(5\),guards\(6\)\).*" - # r".*Expected 1 or more.*guards\(6\).*." + r".*Expected 1 or more.*guards\(6\).*." ) with self.assertRaisesRegex(ChildrenQueryError, expected_regex=regex): rast.reflect() + def test_child_query_error_zero_or_one_expected_multiple_found(self): + """Reflection of a child fact where zero or one fact is expected + should fail with an informative error message when multiple are found.""" + rast = ReifiedAST() + rast.add_reified_files([malformed_ast_files / "zero_or_more_expected_multiple_found.lp"]) + regex = ( + r"(?s).*aggregate\(3,guard\(4\),agg_elements\(7\),guard\(8\)\).*" + r".*Expected 0 or 1.*guard\(4\).*found 2." + ) + with self.assertRaisesRegex(ChildQueryError, expected_regex=regex): + rast.reflect() + class TestReifyReflectAggTheory(TestReifyReflect): """Test cases for reification and reflection of aggregates and @@ -270,6 +283,8 @@ def test_rast_aggregate(self): self.assertReifyEqual("1 {a: b; c}.", ["aggregate.lp"]) with self.subTest(operation="reflect"): self.assertReflectEqual("1 <= { a: b; c }.", ["aggregate.lp"]) + self.setUp() + self.assertReifyReflectEqual("1 <= { a } < 3.", ["aggregate2.lp"]) def test_rast_simple_theory_atom(self): """Test reification and reflection of simple theory atoms @@ -299,7 +314,9 @@ def test_rast_theory_term(self): def test_rast_theory_unparsed_term(self): "Test reification and reflection of an unparsed theory term." - self.assertReifyReflectEqual('&a { +1!-"b" }.', ["theory_unparsed_term.lp"]) + self.assertReifyReflectEqual( + '&a { (+ 1 !- > "b") }.', ["theory_unparsed_term.lp"] + ) def test_rast_head_aggregate(self): "Test reification and reflection of a head aggregate."