Skip to content

Commit

Permalink
Wip switching to strict mypy typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
namcsi committed Dec 7, 2023
1 parent 0cef01e commit 32027fb
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 39 deletions.
35 changes: 30 additions & 5 deletions src/renopro/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from itertools import count
from types import new_class
from typing import Union
from typing import TYPE_CHECKING, Any, Protocol, Union, cast, dataclass_transform

from clorm import (
ComplexTerm,
Expand All @@ -12,6 +12,7 @@
Predicate,
RawField,
StringField,
field,
refine_field,
)
from clorm.orm.core import _PredicateMeta
Expand All @@ -25,6 +26,7 @@
TheoryOperatorTypeField,
TheorySequenceTypeField,
UnaryOperatorField,
__dataclass_transform__
)
from renopro.utils.clorm_utils import combine_fields

Expand All @@ -41,12 +43,19 @@
# which are used to identify child AST facts


@__dataclass_transform__(field_specifiers=(field,))
class _AstPredicateMeta(_PredicateMeta):
def __new__(mcs, cls_name, bases, namespace, **kwargs):
def __new__(
mcs,
cls_name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> "_AstPredicateMeta":
pattern = re.compile(r"(?<!^)(?=[A-Z])")
underscore_lower_cls_name = pattern.sub("_", cls_name).lower()

def id_body(ns):
def id_body(ns: dict[str, Any]) -> None:
ns.update({"id": IdentifierField})

unary = new_class(
Expand All @@ -57,15 +66,31 @@ def id_body(ns):
)
cls = super().__new__(
mcs, cls_name, bases, namespace, name=underscore_lower_cls_name, **kwargs
)
) # type: ignore
cls.unary = unary
cls.unary.non_unary = cls
return cls
return cast(_AstPredicateMeta, cls)


class IdentifierPredicate(Predicate):
id = IdentifierField


# define callback protocol to correctly type the default argument
# behaviour of IdentifierField


class IdentifierPredicateConstructor(Protocol):
def __call__(self, id: Any = ..., /) -> IdentifierPredicate:
...


class AstPredicate(Predicate, metaclass=_AstPredicateMeta):
"""A predicate representing an AST node."""

if TYPE_CHECKING:
unary: IdentifierPredicateConstructor


class Position(ComplexTerm):
"""Complex field representing a position in a text file."""
Expand Down
34 changes: 22 additions & 12 deletions src/renopro/rast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from itertools import count
from pathlib import Path
from typing import (
Any,
Callable,
Iterator,
Literal,
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(self, reify_location: bool = False):
self._reified = FactBase()
self._program_ast: Sequence[AST] = []
self._current_statement: Tuple[int, int] = (0, 0)
self._tuple_pos: count = count()
self._tuple_pos: Iterator[int] = count()
self._init_overrides()
self.reify_location = reify_location

Expand Down Expand Up @@ -157,7 +158,7 @@ def reified_string_doc(self) -> str: # nocoverage
"""
return self._reified.asp_str(commented=True)

def _init_overrides(self):
def _init_overrides(self) -> None:
"""Initialize override functions that change the default
behavior when reifying or reflecting"""
self._reify_overrides = {
Expand Down Expand Up @@ -280,7 +281,7 @@ def _reify_ast_sequence(
self,
ast_seq: Union[ASTSequence, Sequence[Symbol]],
tuple_predicate: Type[preds.AstPredicate],
):
) -> preds.AstPredicate:
"Reify an ast sequence into a list of facts of type tuple_predicate."
tuple_unary_fact = tuple_predicate.unary()
reified_facts = []
Expand Down Expand Up @@ -317,7 +318,9 @@ def _get_type_constructor_from_node(
return symbol_type, symbol_constructor
raise TypeError(f"Node must be of type AST or Symbol, got: {type(node)}")

def _reify_location(self, id_term, location: ast.Location):
def _reify_location(
self, id_term: preds.AstPredicate, location: ast.Location
) -> None:
begin = preds.Position(
location.begin.filename, location.begin.line, location.begin.column
)
Expand All @@ -326,7 +329,7 @@ def _reify_location(self, id_term, location: ast.Location):
)
self._reified.add(preds.Location(id_term.id, begin, end))

def _reify_attr(self, annotation: Type, attr: NodeAttr, field: BaseField):
def _reify_attr(self, annotation: Type, attr: NodeAttr, field: BaseField) -> Any:
"""Reify an AST node's attribute attr based on the type hint
for the respective argument in the AST node's constructor.
This default behavior is overridden in certain cases; see reify_node."""
Expand Down Expand Up @@ -403,7 +406,9 @@ def reify_node(
self._current_statement = (id_, position + 1)
return id_term

def _reify_theory_operators(self, operators: Sequence[str]):
def _reify_theory_operators(
self, operators: Sequence[str]
) -> preds.TheoryOperators.unary:
operators1 = preds.TheoryOperators.unary()
reified_operators = [
preds.TheoryOperators(id=operators1.id, position=p, operator=op)
Expand All @@ -412,7 +417,9 @@ def _reify_theory_operators(self, operators: Sequence[str]):
self._reified.add(reified_operators)
return operators1

def _reify_body_literals(self, nodes: Sequence[ast.AST]):
def _reify_body_literals(
self, nodes: Sequence[ast.AST]
) -> preds.BodyLiterals.unary:
body_lits1 = preds.BodyLiterals.unary()
reified_body_lits = []
for pos, lit in enumerate(nodes, start=0):
Expand Down Expand Up @@ -463,7 +470,7 @@ def _reify_id(self, node):
)
return const1

def _reify_bool(self, boolean: int):
def _reify_bool(self, boolean: int) -> str:
return "true" if boolean == 1 else "false"

def _get_children(
Expand Down Expand Up @@ -600,14 +607,17 @@ def _node_constructor_from_pred(ast_pred: Type[preds.AstPred]) -> NodeConstructo
) # nocoverage

@singledispatchmethod
def reflect_fact(self, fact: preds.AstPred): # nocoverage
def reflect_fact(self, fact: preds.AstPred) -> AST: # nocoverage
"""Convert the input AST element's reified fact representation
back into a the corresponding member of clingo's abstract
syntax tree, recursively reflecting all child facts.
"""
predicate = type(fact)
if pred_override_func := self._reflect_overrides["pred"].get(predicate):
pred_override_func: Optional[Callable[[Any], AST]] = self._reflect_overrides[
"pred"
].get(predicate)
if pred_override_func:
return pred_override_func(fact)
node_constructor = self._node_constructor_from_pred(predicate)

Expand Down Expand Up @@ -654,7 +664,7 @@ def reflect_fact(self, fact: preds.AstPred): # nocoverage
reflected_node = node_constructor(**kwargs_dict)
return reflected_node

def _reflect_bool(self, boolean: Literal["true", "false"]):
def _reflect_bool(self, boolean: Literal["true", "false"]) -> int:
return 1 if boolean == "true" else 0

def _reflect_program(self, program: preds.Program) -> Sequence[AST]:
Expand All @@ -673,7 +683,7 @@ def _reflect_program(self, program: preds.Program) -> Sequence[AST]:
subprogram.extend(statement_nodes)
return subprogram

def reflect(self):
def reflect(self) -> None:
"""Convert stored reified ast facts into a (sequence of) AST
node(s), and it's string representation.
Expand Down
38 changes: 22 additions & 16 deletions src/renopro/utils/clorm_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"Clorm related utility functions."
"""Clorm related utility functions."""
import enum
import inspect
import re
from contextlib import AbstractContextManager
from types import new_class
from typing import Sequence, Type, TypeVar
from types import TracebackType, new_class
from typing import Any, Sequence, Type, TypeVar, cast

from clingo import Symbol
from clorm import BaseField, UnifierNoMatchError
from thefuzz import process # type: ignore

Expand Down Expand Up @@ -33,18 +34,18 @@ def combine_fields(

fields = list(fields)

def _pytocl(value):
def _pytocl(value: Any) -> Symbol:
for f in fields:
try:
return f.pytocl(value)
return f.pytocl(value) # type: ignore
except (TypeError, ValueError, AttributeError):
pass
raise TypeError(f"No combined pytocl() match for value {value}.")

def _cltopy(symbol):
def _cltopy(symbol: Symbol) -> Any:
for f in fields:
try:
return f.cltopy(symbol)
return f.cltopy(symbol) # type: ignore
except (TypeError, ValueError):
pass
raise TypeError(
Expand All @@ -54,7 +55,7 @@ def _cltopy(symbol):
)
)

def body(ns):
def body(ns: dict[str, Any]) -> None:
ns.update({"fields": fields, "pytocl": _pytocl, "cltopy": _cltopy})

return new_class(subclass_name, (BaseField,), {}, body)
Expand Down Expand Up @@ -102,15 +103,15 @@ class IO(str,Enum):

values = set(i.value for i in enum_class)

def _pytocl(py):
def _pytocl(py: enum.Enum) -> Any:
val = py.value
if val not in values:
raise ValueError(
f"'{val}' is not a valid value of enum class '{enum_class.__name__}'"
)
return val

def body(ns):
def body(ns: dict[str, Any]) -> None:
ns.update({"pytocl": _pytocl, "cltopy": enum_class, "enum": enum_class})

return new_class(subclass_name, (parent_field,), {}, body)
Expand Down Expand Up @@ -148,17 +149,22 @@ class ChildrenQueryError(Exception):
"""


class TryUnify(AbstractContextManager):
class TryUnify(AbstractContextManager): # type: ignore
"""Context manager to try some operation that requires unification
of some set of ast facts. Enhance error message if unification fails.
"""

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if exc_type is UnifierNoMatchError:
self.handle_unify_error(exc_value)
self.handle_unify_error(cast(UnifierNoMatchError, exc_value))

@staticmethod
def handle_unify_error(error):
def handle_unify_error(error: UnifierNoMatchError) -> None:
"""Enhance UnifierNoMatchError with some more
useful error messages to help debug the reason unification failed.
Expand All @@ -183,7 +189,7 @@ def handle_unify_error(error):
)
raise UnifierNoMatchError(
inspect.cleandoc(msg), unmatched, error.predicates
) from None
) from None # type: ignore
for idx, arg in enumerate(unmatched.arguments):
# This is very hacky. Should ask Dave for a better
# solution, if there is one.
Expand All @@ -202,5 +208,5 @@ def handle_unify_error(error):
'{arg_field_str}'."""
raise UnifierNoMatchError(
inspect.cleandoc(msg), unmatched, (candidate,)
) from None
) from None # type: ignore
raise RuntimeError("Code should be unreachable") # nocoverage
8 changes: 4 additions & 4 deletions src/renopro/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ class SingleLevelFilter(logging.Filter):
Filter levels.
"""

def __init__(self, passlevel, reject):
def __init__(self, passlevel: int, reject: bool):
# pylint: disable=super-init-not-called
self.passlevel = passlevel
self.reject = reject

def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
if self.reject:
return record.levelno != self.passlevel # nocoverage

return record.levelno == self.passlevel


def setup_logger(name, level):
def setup_logger(name: str, level: int) -> logging.Logger:
"""
Setup logger.
"""
Expand All @@ -45,7 +45,7 @@ def setup_logger(name, level):
logger.setLevel(level)
log_message_str = "{}%(levelname)s:{} - %(message)s{}"

def set_handler(level, color):
def set_handler(level: int, color: str) -> None:
handler = logging.StreamHandler(sys.stderr)
handler.addFilter(SingleLevelFilter(level, False))
handler.setLevel(level)
Expand Down
4 changes: 2 additions & 2 deletions src/renopro/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from argparse import REMAINDER, ArgumentParser
from pathlib import Path
from textwrap import dedent
from typing import Any, cast
from typing import Any, Optional, cast

__all__ = ["get_parser"]

Expand All @@ -33,7 +33,7 @@ def get_parser() -> ArgumentParser:
("debug", logging.DEBUG),
]

def get(levels, name):
def get(levels: list[tuple[str, int]], name: Any) -> Optional[Any]:
for key, val in levels:
if key == name:
return val
Expand Down

0 comments on commit 32027fb

Please sign in to comment.