Skip to content

Commit

Permalink
Format code, add better error messages for failed unification.
Browse files Browse the repository at this point in the history
  • Loading branch information
namcsi committed Jun 26, 2023
1 parent 29a5147 commit 0cd4b75
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 145 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ install_requires =
importlib_metadata;python_version<'3.8'
clingo
clorm
thefuzz

[options.packages.find]
where = src
Expand Down
80 changes: 50 additions & 30 deletions src/renopro/predicates.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
"""Definitions of AST elements as clorm predicates."""
from typing import Union, Sequence, Type
import enum
import inspect
from itertools import count
from typing import Sequence, Type, Union

from clorm import (IntegerField, Predicate, StringField, RawField,
BaseField, combine_fields, define_enum_field,
refine_field, ConstantField)
from clingo import ast
from clorm import (
BaseField,
ConstantField,
IntegerField,
Predicate,
RawField,
StringField,
Unifier,
combine_fields,
define_enum_field,
refine_field,
)

id_count = count()
next_id = lambda: next(id_count)
Expand All @@ -25,15 +34,21 @@ def make_id_predicate(ast_pred):
"""
id_pred_name = ast_pred.__name__ + "1"
id_pred = type(id_pred_name, (Predicate,),
{"id": Identifier_Field(default=next_id),
"Meta": type("Meta", tuple(), {"name": ast_pred.meta.name})})
id_pred = type(
id_pred_name,
(Predicate,),
{
"id": Identifier_Field(default=next_id),
"Meta": type("Meta", tuple(), {"name": ast_pred.meta.name}),
},
)
id_pred2ast_pred.update({id_pred: ast_pred})
return id_pred


def combine_fields_lazily(fields: Sequence[Type[BaseField]], *, name:
str = "") -> Type[BaseField]:
def combine_fields_lazily(
fields: Sequence[Type[BaseField]], *, name: str = ""
) -> Type[BaseField]:
"""Factory function that returns a field sub-class that combines
other fields lazily.
Expand Down Expand Up @@ -66,12 +81,15 @@ def _cltopy(r):
return f.cltopy(r)
except (TypeError, ValueError):
pass
raise TypeError((f"Object '{r}' ({type(r)}) failed to unify "
f"with {subclass_name}."))
raise TypeError(
(f"Object '{r}' ({type(r)}) failed to unify " f"with {subclass_name}.")
)

return type(subclass_name, (BaseField,), {"fields": fields,
"pytocl": _pytocl,
"cltopy": _cltopy})
return type(
subclass_name,
(BaseField,),
{"fields": fields, "pytocl": _pytocl, "cltopy": _cltopy},
)


class String(Predicate):
Expand All @@ -98,8 +116,9 @@ class Variable(Predicate):
Variable1 = make_id_predicate(Variable)


Term_Field = combine_fields_lazily([String1.Field, Number1.Field,
Variable1.Field], name="Term")
Term_Field = combine_fields_lazily(
[String1.Field, Number1.Field, Variable1.Field], name="Term"
)


class Term_Tuple(Predicate):
Expand All @@ -114,6 +133,7 @@ class Term_Tuple(Predicate):
class Function(Predicate):
"""Note: we represent constants as a Function with an empty term
tuple (i.e. no term_tuple fact with a matching identifier"""

id = Identifier_Field
name = ConstantField
arguments = Term_Tuple1.Field
Expand All @@ -137,16 +157,17 @@ class BinaryOperator(str, enum.Enum):
XOr = "^" # bitwise exclusive or


binary_operator_cl2ast = {BinaryOperator[op.name]: ast.BinaryOperator[op.name]
for op in ast.BinaryOperator}
binary_operator_cl2ast = {
BinaryOperator[op.name]: ast.BinaryOperator[op.name] for op in ast.BinaryOperator
}
binary_operator_ast2cl = {v: k for k, v in binary_operator_cl2ast.items()}


class Binary_Operation(Predicate):
id = Identifier_Field
operator = define_enum_field(parent_field=StringField,
enum_class=BinaryOperator,
name="OperatorField")
operator = define_enum_field(
parent_field=StringField, enum_class=BinaryOperator, name="OperatorField"
)
left = Term_Field
right = Term_Field

Expand Down Expand Up @@ -186,9 +207,7 @@ class Sign(str, enum.Enum):

class Literal(Predicate):
id = Identifier_Field
sig = define_enum_field(parent_field=StringField,
enum_class=Sign,
name="SignField")
sig = define_enum_field(parent_field=StringField, enum_class=Sign, name="SignField")
atom = Atom1.Field


Expand Down Expand Up @@ -216,8 +235,9 @@ class Rule(Predicate):
# note that clingo's parser actually allows arbitrary constant as the external_type
# argument of External, but any other value than true or false results in the external
# statement having no effect
ExternalTypeField = refine_field(ConstantField, ["true", "false"],
name="ExternalTypeField")
ExternalTypeField = refine_field(
ConstantField, ["true", "false"], name="ExternalTypeField"
)


class External(Predicate):
Expand Down Expand Up @@ -285,7 +305,7 @@ class Program(Predicate):
Statement_Tuple1,
Constant_Tuple,
Constant_Tuple1,
Program
Program,
]

AST_Predicates = [
Expand Down Expand Up @@ -315,7 +335,7 @@ class Program(Predicate):
Statement_Tuple1,
Constant_Tuple,
Constant_Tuple1,
Program
Program,
]

AST_Fact = Union[
Expand All @@ -332,7 +352,7 @@ class Program(Predicate):
External,
Statement_Tuple,
Constant_Tuple,
Program
Program,
]

AST_Facts = [
Expand All @@ -349,7 +369,7 @@ class Program(Predicate):
External,
Statement_Tuple,
Constant_Tuple,
Program
Program,
]

# Predicates for AST transformation
Expand Down
Loading

0 comments on commit 0cd4b75

Please sign in to comment.