Skip to content

Commit

Permalink
Typing fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
namcsi committed Jan 10, 2024
1 parent c0cbdf3 commit 867a346
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 26 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ line-length = 88

[tool.mypy]
strict = true
exclude = ["tests/*"]
exclude = ["tests/*", "noxfile.py", "init.py"]

[tool.pylsp-mypy]
enabled = true
live_mode = true
strict = true
exclude = ["tests/*"]
exclude = ["tests/*", "noxfile.py", "init.py"]

6 changes: 3 additions & 3 deletions src/renopro/enum_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import enum
import inspect
from types import new_class
from typing import Type, TypeVar
from typing import Type, TypeVar, Any

from clorm import BaseField, ConstantField, StringField

Expand Down Expand Up @@ -49,15 +49,15 @@ class IO(str,Enum):

values = set(i.value for i in enum_class)

def _pytocl(py):
def _pytocl(py: enum.Enum) -> Any:
val = py.value
if val not in values:
raise ValueError(
f"'{val}' is not a valid value of enum class '{enum_class.__name__}'"
)
return val

def body(ns):
def body(ns: dict[str, Any]) -> None:
ns.update({"pytocl": _pytocl, "cltopy": enum_class, "enum": enum_class})

return new_class(subclass_name, (parent_field,), {}, body)
Expand Down
18 changes: 8 additions & 10 deletions src/renopro/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import re
from itertools import count
from types import new_class
from typing import TYPE_CHECKING, Any, Protocol, Union, cast, dataclass_transform
from typing import TYPE_CHECKING, Any, Protocol, Sequence, Type, Union, cast

from clingo import Symbol
from clorm import (
BaseField,
ComplexTerm,
Expand All @@ -14,7 +15,6 @@
Predicate,
RawField,
StringField,
field,
refine_field,
)
from clorm.orm.core import _PredicateMeta
Expand All @@ -28,7 +28,6 @@
TheoryOperatorTypeField,
TheorySequenceTypeField,
UnaryOperatorField,
__dataclass_transform__
)

id_count = count()
Expand All @@ -55,18 +54,18 @@ def combine_fields(

fields = list(fields)

def _pytocl(value):
def _pytocl(value: Any) -> Symbol:
for f in fields:
try:
return f.pytocl(value)
return f.pytocl(value) # type: ignore
except (TypeError, ValueError, AttributeError):
pass
raise TypeError(f"No combined pytocl() match for value {value}.")

def _cltopy(symbol):
def _cltopy(symbol: Symbol) -> Any:
for f in fields:
try:
return f.cltopy(symbol)
return f.cltopy(symbol) # type: ignore
except (TypeError, ValueError):
pass
raise TypeError(
Expand All @@ -76,7 +75,7 @@ def _cltopy(symbol):
)
)

def body(ns):
def body(ns: dict[str, Any]) -> None:
ns.update({"fields": fields, "pytocl": _pytocl, "cltopy": _cltopy})

return new_class(subclass_name, (BaseField,), {}, body)
Expand All @@ -91,7 +90,6 @@ def body(ns):
# which are used to identify child AST facts


@__dataclass_transform__(field_specifiers=(field,))
class _AstPredicateMeta(_PredicateMeta):
def __new__(
mcs,
Expand All @@ -117,7 +115,7 @@ def id_body(ns: dict[str, Any]) -> None:
) # type: ignore
cls.unary = unary
cls.unary.non_unary = cls
return cast(_AstPredicateMeta, cls)
return cls


class IdentifierPredicate(Predicate):
Expand Down
31 changes: 20 additions & 11 deletions src/renopro/rast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import singledispatchmethod
from itertools import count
from pathlib import Path
from types import TracebackType
from typing import (
Any,
Callable,
Expand All @@ -18,6 +19,8 @@
Type,
Union,
overload,
cast,
List
)

from clingo import Control, ast, symbol
Expand Down Expand Up @@ -74,17 +77,22 @@ class TransformationError(Exception):
error or is unsatisfiable."""


class TryUnify(AbstractContextManager):
class TryUnify(AbstractContextManager): # type: ignore
"""Context manager to try some operation that requires unification
of some set of ast facts. Enhance error message if unification fails.
"""

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if exc_type is UnifierNoMatchError:
self.handle_unify_error(exc_value)
self.handle_unify_error(cast(UnifierNoMatchError, exc_value))

@staticmethod
def handle_unify_error(error):
def handle_unify_error(error: UnifierNoMatchError) -> None:
"""Enhance UnifierNoMatchError with some more
useful error messages to help debug the reason unification failed.
Expand All @@ -109,7 +117,7 @@ def handle_unify_error(error):
)
raise UnifierNoMatchError(
inspect.cleandoc(msg), unmatched, error.predicates
) from None
) from None # type: ignore
for idx, arg in enumerate(unmatched.arguments):
# This is very hacky. Should ask Dave for a better
# solution, if there is one.
Expand All @@ -128,7 +136,7 @@ def handle_unify_error(error):
'{arg_field_str}'."""
raise UnifierNoMatchError(
inspect.cleandoc(msg), unmatched, (candidate,)
) from None
) from None # type: ignore
raise RuntimeError("Code should be unreachable") # nocoverage


Expand Down Expand Up @@ -156,7 +164,7 @@ class ReifiedAST:

def __init__(self, reify_location: bool = False):
self._reified = FactBase()
self._program_ast: Sequence[AST] = []
self._program_ast: List[AST] = []
self._current_statement: Tuple[int, int] = (0, 0)
self._tuple_pos: Iterator[int] = count()
self._init_overrides()
Expand Down Expand Up @@ -213,7 +221,7 @@ def reify_files(self, files: Sequence[Path]) -> None:
parse_files(files_str, self.program_ast.append)
self.reify_ast(self._program_ast)

def reify_ast(self, asts: Sequence[AST]) -> None:
def reify_ast(self, asts: List[AST]) -> None:
"""Reify input sequence of AST nodes, adding reified facts to
the internal factbase."""
self._program_ast = asts
Expand All @@ -226,7 +234,7 @@ def program_string(self) -> str:
return "\n".join([str(statement) for statement in self._program_ast])

@property
def program_ast(self) -> Sequence[AST]:
def program_ast(self) -> List[AST]:
"""AST nodes attained via reflection of AST facts."""
return self._program_ast

Expand Down Expand Up @@ -421,7 +429,7 @@ def _reify_location(
)
self._reified.add(preds.Location(id_term.id, begin, end))

def _reify_attr(self, annotation: Type, attr: NodeAttr, field: BaseField) -> Any:
def _reify_attr(self, annotation: Type[NodeAttr], attr: NodeAttr, field: BaseField) -> Any:
"""Reify an AST node's attribute attr based on the type hint
for the respective argument in the AST node's constructor.
This default behavior is overridden in certain cases; see reify_node."""
Expand Down Expand Up @@ -535,7 +543,8 @@ def _reify_body_literals(
self._reified.add(reified_body_lits)
return body_lits1

def _reify_function(self, node):
def _reify_function(self, node: AST) -> preds.IdentifierPredicate:
pred: Type[preds.Function] | Type[preds.ExternalFunction]
if node.external == 0:
pred = preds.Function
func1 = pred.unary()
Expand Down

0 comments on commit 867a346

Please sign in to comment.