From a3a1cd5f0ac5ad0c12071b8e23fe062ac6fd1e8d Mon Sep 17 00:00:00 2001 From: Amade Nemes Date: Sun, 24 Sep 2023 14:34:48 +0200 Subject: [PATCH] Reorganize code. --- src/renopro/enum_fields.py | 145 ++++++++++++++++ src/renopro/predicates.py | 285 ++----------------------------- src/renopro/rast.py | 90 ++-------- src/renopro/utils/clorm_utils.py | 206 ++++++++++++++++++++++ 4 files changed, 374 insertions(+), 352 deletions(-) create mode 100644 src/renopro/enum_fields.py create mode 100644 src/renopro/utils/clorm_utils.py diff --git a/src/renopro/enum_fields.py b/src/renopro/enum_fields.py new file mode 100644 index 0000000..93c2370 --- /dev/null +++ b/src/renopro/enum_fields.py @@ -0,0 +1,145 @@ +"Definitions of enum fields to be used in AST predicates." +import enum + +from clorm import StringField, ConstantField + +from renopro.utils.clorm_utils import define_enum_field + + +class UnaryOperator(str, enum.Enum): + "String enum of clingo's unary operators." + Absolute = "||" # For taking the absolute value. + Minus = "-" # For unary minus and classical negation. + Negation = "~" # For bitwise negation + + +UnaryOperatorField = define_enum_field( + parent_field=StringField, enum_class=UnaryOperator, name="UnaryOperatorField" +) + + +class BinaryOperator(str, enum.Enum): + "String enum of clingo's binary operators." + And = "&" # bitwise and + Division = "/" # arithmetic division + Minus = "-" # arithmetic subtraction + Modulo = "%" # arithmetic modulo + Multiplication = "*" # arithmetic multiplication + Or = "?" # bitwise or + Plus = "+" # arithmetic addition + Power = "**" # arithmetic exponentiation + XOr = "^" # bitwise exclusive or + + +BinaryOperatorField = define_enum_field( + parent_field=StringField, enum_class=BinaryOperator, name="BinaryOperatorField" +) + + +class ComparisonOperator(str, enum.Enum): + """ + String enumeration of clingo's comparison operators. + """ + + Equal = "=" + GreaterEqual = ">=" + GreaterThan = ">" + LessEqual = "<=" + LessThan = "<" + NotEqual = "!=" + + +ComparisonOperatorField = define_enum_field( + parent_field=StringField, + enum_class=ComparisonOperator, + name="ComparisonOperatorField", +) + + +class TheorySequenceType(str, enum.Enum): + """String enum of theory sequence types.""" + + List = "[]" + """ + For sequences enclosed in brackets. + """ + Set = "{}" + """ + For sequences enclosed in braces. + """ + Tuple = "()" + """ + For sequences enclosed in parenthesis. + """ + + +TheorySequenceTypeField = define_enum_field( + parent_field=StringField, + enum_class=TheorySequenceType, + name="TheorySequenceTypeField", +) + + +class Sign(str, enum.Enum): + """String enum of possible sign of a literal.""" + + DoubleNegation = "not not" + """ + For double negated literals (with prefix `not not`) + """ + Negation = "not" + """ + For negative literals (with prefix `not`). + """ + NoSign = "pos" + """ + For positive literals. + """ + + +SignField = define_enum_field( + parent_field=StringField, enum_class=Sign, name="SignField" +) + + +class AggregateFunction(str, enum.Enum): + "String enum of clingo's aggregate functions." + Count = "#count" + Max = "#max" + Min = "#min" + Sum = "#sum" + SumPlus = "#sum+" + + +AggregateFunctionField = define_enum_field( + parent_field=StringField, + enum_class=AggregateFunction, + name="AggregateFunctionField", +) + + +class TheoryOperatorType(str, enum.Enum): + "String enum of clingo's theory definition types" + BinaryLeft = "binary_left" + BinaryRight = "binary_right" + Unary = "unary" + + +TheoryOperatorTypeField = define_enum_field( + parent_field=ConstantField, + enum_class=TheoryOperatorType, + name="TheoryOperatorTypeField", +) + + +class TheoryAtomType(str, enum.Enum): + "String enum of clingo's theory atom types." + Any = "any" + Body = "body" + Directive = "directive" + Head = "head" + + +TheoryAtomTypeField = define_enum_field( + parent_field=ConstantField, enum_class=TheoryAtomType, name="TheoryAtomTypeField" +) diff --git a/src/renopro/predicates.py b/src/renopro/predicates.py index 73a09b5..6e414aa 100644 --- a/src/renopro/predicates.py +++ b/src/renopro/predicates.py @@ -1,14 +1,11 @@ # pylint: disable=too-many-lines """Definitions of AST elements as clorm predicates.""" -import enum -import inspect import re from itertools import count from types import new_class -from typing import Sequence, Type, TypeVar, Union +from typing import Union from clorm import ( - BaseField, ComplexTerm, ConstantField, IntegerField, @@ -19,111 +16,19 @@ ) from clorm.orm.core import _PredicateMeta -id_count = count() - - -def combine_fields( - fields: Sequence[Type[BaseField]], *, name: str = "" -) -> Type[BaseField]: - """Factory function that returns a field sub-class that combines - other fields lazily. - - Essentially the same as the combine_fields defined in the clorm - package, but exposes a 'fields' attrible, allowing us to add - additional fields after the initial combination of fields by - appending to the 'fields' attribute of the combined field. - - """ - subclass_name = name if name else "AnonymousCombinedBaseField" - - # Must combine at least two fields otherwise it doesn't make sense - for f in fields: - if not inspect.isclass(f) or not issubclass(f, BaseField): - raise TypeError("{f} is not BaseField or a sub-class.") - - fields = list(fields) - - def _pytocl(value): - for f in fields: - try: - return f.pytocl(value) - except (TypeError, ValueError, AttributeError): - pass - raise TypeError(f"No combined pytocl() match for value {value}.") - - def _cltopy(symbol): - for f in fields: - try: - return f.cltopy(symbol) - except (TypeError, ValueError): - pass - raise TypeError( - ( - f"Object '{symbol}' ({type(symbol)}) failed to unify " - f"with {subclass_name}." - ) - ) - - def body(ns): - ns.update({"fields": fields, "pytocl": _pytocl, "cltopy": _cltopy}) - - return new_class(subclass_name, (BaseField,), {}, body) - - -def define_enum_field( - parent_field: Type[BaseField], enum_class: Type[enum.Enum], *, name: str = "" -) -> Type[BaseField]: # nocoverage - """Factory function that returns a BaseField sub-class for an - Enum. Essentially the same as the one defined in clorm, but stores - the enum that defines the field under attribute 'enum' for later - use. - - Enums are part of the standard library since Python 3.4. This method - provides an alternative to using refine_field() to provide a restricted set - of allowable values. - - Example: - .. code-block:: python - - class IO(str,Enum): - IN="in" - OUT="out" - - # A field that unifies against ASP constants "in" and "out" - IOField = define_enum_field(ConstantField,IO) - - Positional argument: - - field_class: the field that is being sub-classed - - enum_class: the Enum class - - Optional keyword-only arguments: - - name: name for new class (default: anonymously generated). - - """ - subclass_name = name if name else parent_field.__name__ + "_Restriction" - if not inspect.isclass(parent_field) or not issubclass(parent_field, BaseField): - raise TypeError(f"{parent_field} is not a subclass of BaseField") - - if not inspect.isclass(enum_class) or not issubclass(enum_class, enum.Enum): - raise TypeError(f"{enum_class} is not a subclass of enum.Enum") - - values = set(i.value for i in enum_class) - - def _pytocl(py): - val = py.value - if val not in values: - raise ValueError( - f"'{val}' is not a valid value of enum class '{enum_class.__name__}'" - ) - return val - - def body(ns): - ns.update({"pytocl": _pytocl, "cltopy": enum_class, "enum": enum_class}) +from renopro.utils.clorm_utils import combine_fields +from renopro.enum_fields import ( + SignField, + UnaryOperatorField, + BinaryOperatorField, + TheorySequenceTypeField, + ComparisonOperatorField, + AggregateFunctionField, + TheoryOperatorTypeField, + TheoryAtomTypeField +) - return new_class(subclass_name, (parent_field,), {}, body) +id_count = count() # by default we use integer identifiers, but allow arbitrary symbols as well @@ -132,170 +37,6 @@ def body(ns): IdentifierField = IdentifierField(default=lambda: next(id_count)) # type: ignore -# Enum field definitions - - -class UnaryOperator(str, enum.Enum): - "String enum of clingo's unary operators." - Absolute = "||" # For taking the absolute value. - Minus = "-" # For unary minus and classical negation. - Negation = "~" # For bitwise negation - - -UnaryOperatorField = define_enum_field( - parent_field=StringField, enum_class=UnaryOperator, name="UnaryOperatorField" -) - - -class BinaryOperator(str, enum.Enum): - "String enum of clingo's binary operators." - And = "&" # bitwise and - Division = "/" # arithmetic division - Minus = "-" # arithmetic subtraction - Modulo = "%" # arithmetic modulo - Multiplication = "*" # arithmetic multiplication - Or = "?" # bitwise or - Plus = "+" # arithmetic addition - Power = "**" # arithmetic exponentiation - XOr = "^" # bitwise exclusive or - - -BinaryOperatorField = define_enum_field( - parent_field=StringField, enum_class=BinaryOperator, name="BinaryOperatorField" -) - - -class ComparisonOperator(str, enum.Enum): - """ - String enumeration of clingo's comparison operators. - """ - - Equal = "=" - GreaterEqual = ">=" - GreaterThan = ">" - LessEqual = "<=" - LessThan = "<" - NotEqual = "!=" - - -ComparisonOperatorField = define_enum_field( - parent_field=StringField, - enum_class=ComparisonOperator, - name="ComparisonOperatorField", -) - - -class TheorySequenceType(str, enum.Enum): - """String enum of theory sequence types.""" - - List = "[]" - """ - For sequences enclosed in brackets. - """ - Set = "{}" - """ - For sequences enclosed in braces. - """ - Tuple = "()" - """ - For sequences enclosed in parenthesis. - """ - - -TheorySequenceTypeField = define_enum_field( - parent_field=StringField, - enum_class=TheorySequenceType, - name="TheorySequenceTypeField", -) - - -class Sign(str, enum.Enum): - """String enum of possible sign of a literal.""" - - DoubleNegation = "not not" - """ - For double negated literals (with prefix `not not`) - """ - Negation = "not" - """ - For negative literals (with prefix `not`). - """ - NoSign = "pos" - """ - For positive literals. - """ - - -SignField = define_enum_field( - parent_field=StringField, enum_class=Sign, name="SignField" -) - - -class AggregateFunction(str, enum.Enum): - "String enum of clingo's aggregate functions." - Count = "#count" - Max = "#max" - Min = "#min" - Sum = "#sum" - SumPlus = "#sum+" - - -AggregateFunctionField = define_enum_field( - parent_field=StringField, - enum_class=AggregateFunction, - name="AggregateFunctionField", -) - - -class TheoryOperatorType(str, enum.Enum): - "String enum of clingo's theory definition types" - BinaryLeft = "binary_left" - BinaryRight = "binary_right" - Unary = "unary" - - -TheoryOperatorTypeField = define_enum_field( - parent_field=ConstantField, - enum_class=TheoryOperatorType, - name="TheoryOperatorTypeField", -) - - -class TheoryAtomType(str, enum.Enum): - "String enum of clingo's theory atom types." - Any = "any" - Body = "body" - Directive = "directive" - Head = "head" - - -TheoryAtomTypeField = define_enum_field( - parent_field=ConstantField, enum_class=TheoryAtomType, name="TheoryAtomTypeField" -) - - -A = TypeVar("A", bound=enum.Enum) -B = TypeVar("B", bound=enum.Enum) - - -def convert_enum(enum_member: A, other_enum: Type[B]) -> B: - """Given an enum_member, convert it to the other_enum member of - the same name. - """ - # enum_type = type(enum_member) - # cast to enum - needed as enum members stored in a clingo AST object - # gets cast to it's raw value for some reason - # enum_member = enum_type(enum_member) - try: - return other_enum[enum_member.name] - except KeyError as exc: # nocoverage - msg = ( - f"Enum {other_enum} has no corresponding member " - f"with name {enum_member.name}" - ) - raise ValueError(msg) from exc - - # Metaclass shenanigans to dynamically create unary versions of AST predicates, # which are used to identify child AST facts diff --git a/src/renopro/rast.py b/src/renopro/rast.py index 05a1526..8231fcd 100644 --- a/src/renopro/rast.py +++ b/src/renopro/rast.py @@ -1,9 +1,6 @@ # pylint: disable=too-many-lines """Module implementing reification and de-reification of non-ground programs""" -import inspect import logging -import re -from contextlib import AbstractContextManager from functools import singledispatchmethod from itertools import count from pathlib import Path @@ -35,15 +32,20 @@ BaseField, FactBase, Unifier, - UnifierNoMatchError, control_add_facts, parse_fact_files, parse_fact_string, ) -from thefuzz import process # type: ignore +import renopro.enum_fields as enums import renopro.predicates as preds from renopro.utils import assert_never +from renopro.utils.clorm_utils import ( + ChildQueryError, + ChildrenQueryError, + TryUnify, + convert_enum, +) from renopro.utils.logger import get_clingo_logger_callback logger = logging.getLogger(__name__) @@ -54,78 +56,6 @@ NodeAttr = Union[AST, Symbol, Sequence[Symbol], ASTSequence, str, int, StrSequence] -class ChildQueryError(Exception): - """Exception raised when a required child fact of an AST fact - cannot be found. - - """ - - -class ChildrenQueryError(Exception): - """Exception raised when the expected number child facts of an AST - fact cannot be found. - - """ - - -class TryUnify(AbstractContextManager): - """Context manager to try some operation that requires unification - of some set of ast facts. Enhance error message if unification fails. - """ - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is UnifierNoMatchError: - self.handle_unify_error(exc_value) - - @staticmethod - def handle_unify_error(error): - """Enhance UnifierNoMatchError with some more - useful error messages to help debug the reason unification failed. - - """ - unmatched = error.symbol - name2arity2pred = { - pred.meta.name: {pred.meta.arity: pred} for pred in preds.AstPreds - } - candidate = name2arity2pred.get(unmatched.name, {}).get( - len(unmatched.arguments) - ) - if candidate is None: - fuzzy_name = process.extractOne(unmatched.name, name2arity2pred.keys())[0] - signatures = [ - f"{fuzzy_name}/{arity}." for arity in name2arity2pred[fuzzy_name] - ] - msg = f"""No AST fact of matching signature found for symbol - '{unmatched}'. - Similar AST fact signatures are: - """ + "\n".join( - signatures - ) - raise UnifierNoMatchError( - inspect.cleandoc(msg), unmatched, error.predicates - ) from None - for idx, arg in enumerate(unmatched.arguments): - # This is very hacky. Should ask Dave for a better - # solution, if there is one. - arg_field = candidate[idx]._field # pylint: disable=protected-access - arg_field_str = re.sub(r"\(.*?\)", "", str(arg_field)) - try: - arg_field.cltopy(arg) - except (TypeError, ValueError): - msg = f"""Cannot unify symbol - '{unmatched}' - to only candidate AST fact of matching signature - {candidate.meta.name}/{candidate.meta.arity} - due to failure to unify symbol's argument - '{arg}' - against the corresponding field - '{arg_field_str}'.""" - raise UnifierNoMatchError( - inspect.cleandoc(msg), unmatched, (candidate,) - ) from None - raise RuntimeError("Code should be unreachable") # nocoverage - - class ReifiedAST: """Class for converting between reified and non-reified representation of ASP programs. @@ -410,7 +340,7 @@ def _reify_attr(self, annotation: Type, attr: NodeAttr, field: BaseField): ) if hasattr(field, "enum"): ast_enum = getattr(ast, field.enum.__name__) - return preds.convert_enum(ast_enum(attr), field.enum) + return convert_enum(ast_enum(attr), field.enum) if annotation in [str, int]: return attr raise RuntimeError("Code should be unreachable.") # nocoverage @@ -500,7 +430,7 @@ def _reify_body_literals(self, nodes: Sequence[ast.AST]): id=body_lits1.id, position=pos, body_literal=body_lit1 ) ) - clorm_sign = preds.convert_enum(ast.Sign(lit.sign), preds.Sign) + clorm_sign = convert_enum(ast.Sign(lit.sign), enums.Sign) body_lit = preds.BodyLiteral( id=body_lit1.id, sign_=clorm_sign, atom=self.reify_node(lit.atom) ) @@ -708,7 +638,7 @@ def reflect_fact(self, fact: preds.AstPred): # nocoverage kwargs_dict.update({key: attr_override_func(fact)}) elif clorm_enum := getattr(field, "enum", None): ast_enum = getattr(ast, clorm_enum.__name__) - ast_enum_member = preds.convert_enum(field_val, ast_enum) + ast_enum_member = convert_enum(field_val, ast_enum) kwargs_dict.update({key: ast_enum_member}) elif child_type in [str, int]: kwargs_dict.update({key: field_val}) diff --git a/src/renopro/utils/clorm_utils.py b/src/renopro/utils/clorm_utils.py new file mode 100644 index 0000000..42cbe08 --- /dev/null +++ b/src/renopro/utils/clorm_utils.py @@ -0,0 +1,206 @@ +"Clorm related utility functions." +import enum +import inspect +import re +from contextlib import AbstractContextManager +from types import new_class +from typing import Sequence, Type, TypeVar + +from clorm import BaseField, UnifierNoMatchError +from thefuzz import process # type: ignore + +from renopro import predicates as preds + + +def combine_fields( + fields: Sequence[Type[BaseField]], *, name: str = "" +) -> Type[BaseField]: + """Factory function that returns a field sub-class that combines + other fields lazily. + + Essentially the same as the combine_fields defined in the clorm + package, but exposes a 'fields' attrible, allowing us to add + additional fields after the initial combination of fields by + appending to the 'fields' attribute of the combined field. + + """ + subclass_name = name if name else "AnonymousCombinedBaseField" + + # Must combine at least two fields otherwise it doesn't make sense + for f in fields: + if not inspect.isclass(f) or not issubclass(f, BaseField): + raise TypeError("{f} is not BaseField or a sub-class.") + + fields = list(fields) + + def _pytocl(value): + for f in fields: + try: + return f.pytocl(value) + except (TypeError, ValueError, AttributeError): + pass + raise TypeError(f"No combined pytocl() match for value {value}.") + + def _cltopy(symbol): + for f in fields: + try: + return f.cltopy(symbol) + except (TypeError, ValueError): + pass + raise TypeError( + ( + f"Object '{symbol}' ({type(symbol)}) failed to unify " + f"with {subclass_name}." + ) + ) + + def body(ns): + ns.update({"fields": fields, "pytocl": _pytocl, "cltopy": _cltopy}) + + return new_class(subclass_name, (BaseField,), {}, body) + + +def define_enum_field( + parent_field: Type[BaseField], enum_class: Type[enum.Enum], *, name: str = "" +) -> Type[BaseField]: # nocoverage + """Factory function that returns a BaseField sub-class for an + Enum. Essentially the same as the one defined in clorm, but stores + the enum that defines the field under attribute 'enum' for later + use. + + Enums are part of the standard library since Python 3.4. This method + provides an alternative to using refine_field() to provide a restricted set + of allowable values. + + Example: + .. code-block:: python + + class IO(str,Enum): + IN="in" + OUT="out" + + # A field that unifies against ASP constants "in" and "out" + IOField = define_enum_field(ConstantField,IO) + + Positional argument: + + field_class: the field that is being sub-classed + + enum_class: the Enum class + + Optional keyword-only arguments: + + name: name for new class (default: anonymously generated). + + """ + subclass_name = name if name else parent_field.__name__ + "_Restriction" + if not inspect.isclass(parent_field) or not issubclass(parent_field, BaseField): + raise TypeError(f"{parent_field} is not a subclass of BaseField") + + if not inspect.isclass(enum_class) or not issubclass(enum_class, enum.Enum): + raise TypeError(f"{enum_class} is not a subclass of enum.Enum") + + values = set(i.value for i in enum_class) + + def _pytocl(py): + val = py.value + if val not in values: + raise ValueError( + f"'{val}' is not a valid value of enum class '{enum_class.__name__}'" + ) + return val + + def body(ns): + ns.update({"pytocl": _pytocl, "cltopy": enum_class, "enum": enum_class}) + + return new_class(subclass_name, (parent_field,), {}, body) + + +A = TypeVar("A", bound=enum.Enum) +B = TypeVar("B", bound=enum.Enum) + + +def convert_enum(enum_member: A, other_enum: Type[B]) -> B: + """Given an enum_member, convert it to the other_enum member of + the same name. + """ + try: + return other_enum[enum_member.name] + except KeyError as exc: # nocoverage + msg = ( + f"Enum {other_enum} has no corresponding member " + f"with name {enum_member.name}" + ) + raise ValueError(msg) from exc + + +class ChildQueryError(Exception): + """Exception raised when a required child fact of an AST fact + cannot be found. + + """ + + +class ChildrenQueryError(Exception): + """Exception raised when the expected number child facts of an AST + fact cannot be found. + + """ + + +class TryUnify(AbstractContextManager): + """Context manager to try some operation that requires unification + of some set of ast facts. Enhance error message if unification fails. + """ + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is UnifierNoMatchError: + self.handle_unify_error(exc_value) + + @staticmethod + def handle_unify_error(error): + """Enhance UnifierNoMatchError with some more + useful error messages to help debug the reason unification failed. + + """ + unmatched = error.symbol + name2arity2pred = { + pred.meta.name: {pred.meta.arity: pred} for pred in preds.AstPreds + } + candidate = name2arity2pred.get(unmatched.name, {}).get( + len(unmatched.arguments) + ) + if candidate is None: + fuzzy_name = process.extractOne(unmatched.name, name2arity2pred.keys())[0] + signatures = [ + f"{fuzzy_name}/{arity}." for arity in name2arity2pred[fuzzy_name] + ] + msg = f"""No AST fact of matching signature found for symbol + '{unmatched}'. + Similar AST fact signatures are: + """ + "\n".join( + signatures + ) + raise UnifierNoMatchError( + inspect.cleandoc(msg), unmatched, error.predicates + ) from None + for idx, arg in enumerate(unmatched.arguments): + # This is very hacky. Should ask Dave for a better + # solution, if there is one. + arg_field = candidate[idx]._field # pylint: disable=protected-access + arg_field_str = re.sub(r"\(.*?\)", "", str(arg_field)) + try: + arg_field.cltopy(arg) + except (TypeError, ValueError): + msg = f"""Cannot unify symbol + '{unmatched}' + to only candidate AST fact of matching signature + {candidate.meta.name}/{candidate.meta.arity} + due to failure to unify symbol's argument + '{arg}' + against the corresponding field + '{arg_field_str}'.""" + raise UnifierNoMatchError( + inspect.cleandoc(msg), unmatched, (candidate,) + ) from None + raise RuntimeError("Code should be unreachable") # nocoverage