From 867a3469e5f0b0941d6f3e09987452f6b9b1d09a Mon Sep 17 00:00:00 2001 From: Amade Nemes Date: Wed, 10 Jan 2024 08:46:39 +0100 Subject: [PATCH] Typing fixes. --- pyproject.toml | 4 ++-- src/renopro/enum_fields.py | 6 +++--- src/renopro/predicates.py | 18 ++++++++---------- src/renopro/rast.py | 31 ++++++++++++++++++++----------- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 11af4a3..104c3c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,11 @@ line-length = 88 [tool.mypy] strict = true -exclude = ["tests/*"] +exclude = ["tests/*", "noxfile.py", "init.py"] [tool.pylsp-mypy] enabled = true live_mode = true strict = true -exclude = ["tests/*"] +exclude = ["tests/*", "noxfile.py", "init.py"] diff --git a/src/renopro/enum_fields.py b/src/renopro/enum_fields.py index c4d1a62..be39081 100644 --- a/src/renopro/enum_fields.py +++ b/src/renopro/enum_fields.py @@ -2,7 +2,7 @@ import enum import inspect from types import new_class -from typing import Type, TypeVar +from typing import Type, TypeVar, Any from clorm import BaseField, ConstantField, StringField @@ -49,7 +49,7 @@ class IO(str,Enum): values = set(i.value for i in enum_class) - def _pytocl(py): + def _pytocl(py: enum.Enum) -> Any: val = py.value if val not in values: raise ValueError( @@ -57,7 +57,7 @@ def _pytocl(py): ) return val - def body(ns): + def body(ns: dict[str, Any]) -> None: ns.update({"pytocl": _pytocl, "cltopy": enum_class, "enum": enum_class}) return new_class(subclass_name, (parent_field,), {}, body) diff --git a/src/renopro/predicates.py b/src/renopro/predicates.py index d287c96..a381d07 100644 --- a/src/renopro/predicates.py +++ b/src/renopro/predicates.py @@ -4,8 +4,9 @@ import re from itertools import count from types import new_class -from typing import TYPE_CHECKING, Any, Protocol, Union, cast, dataclass_transform +from typing import TYPE_CHECKING, Any, Protocol, Sequence, Type, Union, cast +from clingo import Symbol from clorm import ( BaseField, ComplexTerm, @@ -14,7 +15,6 @@ Predicate, RawField, StringField, - field, refine_field, ) from clorm.orm.core import _PredicateMeta @@ -28,7 +28,6 @@ TheoryOperatorTypeField, TheorySequenceTypeField, UnaryOperatorField, - __dataclass_transform__ ) id_count = count() @@ -55,18 +54,18 @@ def combine_fields( fields = list(fields) - def _pytocl(value): + def _pytocl(value: Any) -> Symbol: for f in fields: try: - return f.pytocl(value) + return f.pytocl(value) # type: ignore except (TypeError, ValueError, AttributeError): pass raise TypeError(f"No combined pytocl() match for value {value}.") - def _cltopy(symbol): + def _cltopy(symbol: Symbol) -> Any: for f in fields: try: - return f.cltopy(symbol) + return f.cltopy(symbol) # type: ignore except (TypeError, ValueError): pass raise TypeError( @@ -76,7 +75,7 @@ def _cltopy(symbol): ) ) - def body(ns): + def body(ns: dict[str, Any]) -> None: ns.update({"fields": fields, "pytocl": _pytocl, "cltopy": _cltopy}) return new_class(subclass_name, (BaseField,), {}, body) @@ -91,7 +90,6 @@ def body(ns): # which are used to identify child AST facts -@__dataclass_transform__(field_specifiers=(field,)) class _AstPredicateMeta(_PredicateMeta): def __new__( mcs, @@ -117,7 +115,7 @@ def id_body(ns: dict[str, Any]) -> None: ) # type: ignore cls.unary = unary cls.unary.non_unary = cls - return cast(_AstPredicateMeta, cls) + return cls class IdentifierPredicate(Predicate): diff --git a/src/renopro/rast.py b/src/renopro/rast.py index a33531c..252a3be 100644 --- a/src/renopro/rast.py +++ b/src/renopro/rast.py @@ -7,6 +7,7 @@ from functools import singledispatchmethod from itertools import count from pathlib import Path +from types import TracebackType from typing import ( Any, Callable, @@ -18,6 +19,8 @@ Type, Union, overload, + cast, + List ) from clingo import Control, ast, symbol @@ -74,17 +77,22 @@ class TransformationError(Exception): error or is unsatisfiable.""" -class TryUnify(AbstractContextManager): +class TryUnify(AbstractContextManager): # type: ignore """Context manager to try some operation that requires unification of some set of ast facts. Enhance error message if unification fails. """ - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: if exc_type is UnifierNoMatchError: - self.handle_unify_error(exc_value) + self.handle_unify_error(cast(UnifierNoMatchError, exc_value)) @staticmethod - def handle_unify_error(error): + def handle_unify_error(error: UnifierNoMatchError) -> None: """Enhance UnifierNoMatchError with some more useful error messages to help debug the reason unification failed. @@ -109,7 +117,7 @@ def handle_unify_error(error): ) raise UnifierNoMatchError( inspect.cleandoc(msg), unmatched, error.predicates - ) from None + ) from None # type: ignore for idx, arg in enumerate(unmatched.arguments): # This is very hacky. Should ask Dave for a better # solution, if there is one. @@ -128,7 +136,7 @@ def handle_unify_error(error): '{arg_field_str}'.""" raise UnifierNoMatchError( inspect.cleandoc(msg), unmatched, (candidate,) - ) from None + ) from None # type: ignore raise RuntimeError("Code should be unreachable") # nocoverage @@ -156,7 +164,7 @@ class ReifiedAST: def __init__(self, reify_location: bool = False): self._reified = FactBase() - self._program_ast: Sequence[AST] = [] + self._program_ast: List[AST] = [] self._current_statement: Tuple[int, int] = (0, 0) self._tuple_pos: Iterator[int] = count() self._init_overrides() @@ -213,7 +221,7 @@ def reify_files(self, files: Sequence[Path]) -> None: parse_files(files_str, self.program_ast.append) self.reify_ast(self._program_ast) - def reify_ast(self, asts: Sequence[AST]) -> None: + def reify_ast(self, asts: List[AST]) -> None: """Reify input sequence of AST nodes, adding reified facts to the internal factbase.""" self._program_ast = asts @@ -226,7 +234,7 @@ def program_string(self) -> str: return "\n".join([str(statement) for statement in self._program_ast]) @property - def program_ast(self) -> Sequence[AST]: + def program_ast(self) -> List[AST]: """AST nodes attained via reflection of AST facts.""" return self._program_ast @@ -421,7 +429,7 @@ def _reify_location( ) self._reified.add(preds.Location(id_term.id, begin, end)) - def _reify_attr(self, annotation: Type, attr: NodeAttr, field: BaseField) -> Any: + def _reify_attr(self, annotation: Type[NodeAttr], attr: NodeAttr, field: BaseField) -> Any: """Reify an AST node's attribute attr based on the type hint for the respective argument in the AST node's constructor. This default behavior is overridden in certain cases; see reify_node.""" @@ -535,7 +543,8 @@ def _reify_body_literals( self._reified.add(reified_body_lits) return body_lits1 - def _reify_function(self, node): + def _reify_function(self, node: AST) -> preds.IdentifierPredicate: + pred: Type[preds.Function] | Type[preds.ExternalFunction] if node.external == 0: pred = preds.Function func1 = pred.unary()